diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp index 95abffe5120..2db0d13abf2 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp @@ -4,13 +4,12 @@ #include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp" namespace bb::eccvm { - -static constexpr size_t NUM_SCALAR_BITS = 128; -static constexpr size_t WNAF_SLICE_BITS = 4; -static constexpr size_t NUM_WNAF_SLICES = (NUM_SCALAR_BITS + WNAF_SLICE_BITS - 1) / WNAF_SLICE_BITS; -static constexpr uint64_t WNAF_MASK = static_cast((1ULL << WNAF_SLICE_BITS) - 1ULL); -static constexpr size_t POINT_TABLE_SIZE = 1ULL << (WNAF_SLICE_BITS); -static constexpr size_t WNAF_SLICES_PER_ROW = 4; +static constexpr size_t NUM_SCALAR_BITS = 128; // The length of scalars handled by the ECCVVM +static constexpr size_t NUM_WNAF_DIGIT_BITS = 4; // Scalars are decompose into base 16 in wNAF form +static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = NUM_SCALAR_BITS / NUM_WNAF_DIGIT_BITS; // 32 +static constexpr uint64_t WNAF_MASK = static_cast((1ULL << NUM_WNAF_DIGIT_BITS) - 1ULL); +static constexpr size_t POINT_TABLE_SIZE = 1ULL << (NUM_WNAF_DIGIT_BITS); +static constexpr size_t WNAF_DIGITS_PER_ROW = 4; static constexpr size_t ADDITIONS_PER_ROW = 4; template struct VMOperation { @@ -39,7 +38,7 @@ template struct ScalarMul { uint32_t pc; uint256_t scalar; typename CycleGroup::affine_element base_point; - std::array wnaf_slices; + std::array wnaf_digits; bool wnaf_skew; // size bumped by 1 to record base_point.dbl() std::array precomputed_table; diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp index b295133b12a..7f49af86030 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp @@ -24,11 +24,11 @@ class ECCVMCircuitBuilder { using AffineElement = typename CycleGroup::affine_element; static constexpr size_t NUM_SCALAR_BITS = bb::eccvm::NUM_SCALAR_BITS; - static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS; - static constexpr size_t NUM_WNAF_SLICES = bb::eccvm::NUM_WNAF_SLICES; + static constexpr size_t NUM_WNAF_DIGIT_BITS = bb::eccvm::NUM_WNAF_DIGIT_BITS; + static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR; static constexpr uint64_t WNAF_MASK = bb::eccvm::WNAF_MASK; static constexpr size_t POINT_TABLE_SIZE = bb::eccvm::POINT_TABLE_SIZE; - static constexpr size_t WNAF_SLICES_PER_ROW = bb::eccvm::WNAF_SLICES_PER_ROW; + static constexpr size_t WNAF_DIGITS_PER_ROW = bb::eccvm::WNAF_DIGITS_PER_ROW; static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW; using MSM = bb::eccvm::MSM; @@ -50,7 +50,8 @@ class ECCVMCircuitBuilder { /** * For input point [P], return { -15[P], -13[P], ..., -[P], [P], ..., 13[P], 15[P] } */ - const auto compute_precomputed_table = [](const AffineElement& base_point) { + const auto compute_precomputed_table = + [](const AffineElement& base_point) -> std::array { const auto d2 = Element(base_point).dbl(); std::array table; table[POINT_TABLE_SIZE] = d2; // need this for later @@ -69,10 +70,10 @@ class ECCVMCircuitBuilder { } return result; }; - const auto compute_wnaf_slices = [](uint256_t scalar) { - std::array output; + const auto compute_wnaf_digits = [](uint256_t scalar) -> std::array { + std::array output; int previous_slice = 0; - for (size_t i = 0; i < NUM_WNAF_SLICES; ++i) { + for (size_t i = 0; i < NUM_WNAF_DIGITS_PER_SCALAR; ++i) { // slice the scalar into 4-bit chunks, starting with the least significant bits uint64_t raw_slice = static_cast(scalar) & WNAF_MASK; @@ -86,19 +87,19 @@ class ECCVMCircuitBuilder { } else if (is_even) { // for other slices, if it's even, we add 1 to the slice value // and subtract 16 from the previous slice to preserve the total scalar sum - static constexpr int borrow_constant = static_cast(1ULL << WNAF_SLICE_BITS); + static constexpr int borrow_constant = static_cast(1ULL << NUM_WNAF_DIGIT_BITS); previous_slice -= borrow_constant; wnaf_slice += 1; } if (i > 0) { const size_t idx = i - 1; - output[NUM_WNAF_SLICES - idx - 1] = previous_slice; + output[NUM_WNAF_DIGITS_PER_SCALAR - idx - 1] = previous_slice; } previous_slice = wnaf_slice; // downshift raw_slice by 4 bits - scalar = scalar >> WNAF_SLICE_BITS; + scalar = scalar >> NUM_WNAF_DIGIT_BITS; } ASSERT(scalar == 0); @@ -108,8 +109,6 @@ class ECCVMCircuitBuilder { return output; }; - // a vector of MSMs = a vector of a vector of scalar muls - // each mul size_t msm_count = 0; size_t active_mul_count = 0; std::vector msm_opqueue_index; @@ -118,6 +117,7 @@ class ECCVMCircuitBuilder { const auto& raw_ops = op_queue->get_raw_ops(); size_t op_idx = 0; + // populate opqueue and mul indices for (const auto& op : raw_ops) { if (op.mul) { if (op.z1 != 0 || op.z2 != 0) { @@ -142,39 +142,38 @@ class ECCVMCircuitBuilder { msm_sizes.push_back(active_mul_count); msm_count++; } - std::vector msms_test(msm_count); + std::vector result(msm_count); for (size_t i = 0; i < msm_count; ++i) { - auto& msm = msms_test[i]; + auto& msm = result[i]; msm.resize(msm_sizes[i]); } run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { - const size_t opqueue_index = msm_opqueue_index[i]; - const auto& op = raw_ops[opqueue_index]; + const auto& op = raw_ops[msm_opqueue_index[i]]; auto [msm_index, mul_index] = msm_mul_index[i]; if (op.z1 != 0) { - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); - msms_test[msm_index][mul_index] = (ScalarMul{ + ASSERT(result.size() > msm_index); + ASSERT(result[msm_index].size() > mul_index); + result[msm_index][mul_index] = (ScalarMul{ .pc = 0, .scalar = op.z1, .base_point = op.base_point, - .wnaf_slices = compute_wnaf_slices(op.z1), + .wnaf_digits = compute_wnaf_digits(op.z1), .wnaf_skew = (op.z1 & 1) == 0, .precomputed_table = compute_precomputed_table(op.base_point), }); mul_index++; } if (op.z2 != 0) { - ASSERT(msms_test.size() > msm_index); - ASSERT(msms_test[msm_index].size() > mul_index); + ASSERT(result.size() > msm_index); + ASSERT(result[msm_index].size() > mul_index); auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y }; - msms_test[msm_index][mul_index] = (ScalarMul{ + result[msm_index][mul_index] = (ScalarMul{ .pc = 0, .scalar = op.z2, .base_point = endo_point, - .wnaf_slices = compute_wnaf_slices(op.z2), + .wnaf_digits = compute_wnaf_digits(op.z2), .wnaf_skew = (op.z2 & 1) == 0, .precomputed_table = compute_precomputed_table(endo_point), }); @@ -191,7 +190,7 @@ class ECCVMCircuitBuilder { // sumcheck relations that involve pc (if we did the other way around, starting at 1 and ending at num_muls, // we create a discontinuity in pc values between the last transcript row and the following empty row) uint32_t pc = num_muls; - for (auto& msm : msms_test) { + for (auto& msm : result) { for (auto& mul : msm) { mul.pc = pc; pc--; @@ -199,7 +198,7 @@ class ECCVMCircuitBuilder { } ASSERT(pc == 0); - return msms_test; + return result; } static std::vector get_flattened_scalar_muls(const std::vector& msms) diff --git a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp index e1828ca8fe4..759353edb0a 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/eccvm_flavor.hpp @@ -34,6 +34,7 @@ class ECCVMFlavor { using CommitmentKey = bb::CommitmentKey; using VerifierCommitmentKey = bb::VerifierCommitmentKey; using RelationSeparator = FF; + using MSM = bb::eccvm::MSM; static constexpr size_t NUM_WIRES = 74; @@ -358,6 +359,7 @@ class ECCVMFlavor { ProverPolynomials& operator=(ProverPolynomials&& o) noexcept = default; ~ProverPolynomials() = default; [[nodiscard]] size_t get_polynomial_size() const { return this->lagrange_first.size(); } + /** * @brief Returns the evaluations of all prover polynomials at one point on the boolean hypercube, which * represents one row in the execution trace. @@ -460,33 +462,28 @@ class ECCVMFlavor { */ ProverPolynomials(const CircuitBuilder& builder) { - const auto msms = builder.get_msms(); - const auto flattened_muls = builder.get_flattened_scalar_muls(msms); - - std::array, 2> point_table_read_counts; - const auto transcript_state = ECCVMTranscriptBuilder::compute_transcript_state( - builder.op_queue->get_raw_ops(), builder.get_number_of_muls()); - const auto precompute_table_state = ECCVMPrecomputedTablesBuilder::compute_precompute_state(flattened_muls); - const auto msm_state = ECCVMMSMMBuilder::compute_msm_state( - msms, point_table_read_counts, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); - - const size_t msm_size = msm_state.size(); - const size_t transcript_size = transcript_state.size(); - const size_t precompute_table_size = precompute_table_state.size(); - - const size_t num_rows = std::max(precompute_table_size, std::max(msm_size, transcript_size)); - - const auto num_rows_log2 = static_cast(numeric::get_msb64(num_rows)); - size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1)); + // compute rows for the three different sections of the ECCVM execution trace + const auto transcript_rows = + ECCVMTranscriptBuilder::compute_rows(builder.op_queue->get_raw_ops(), builder.get_number_of_muls()); + const std::vector msms = builder.get_msms(); + const auto point_table_rows = + ECCVMPointTablePrecomputationBuilder::compute_rows(CircuitBuilder::get_flattened_scalar_muls(msms)); + const auto [msm_rows, point_table_read_counts] = ECCVMMSMMBuilder::compute_rows( + msms, builder.get_number_of_muls(), builder.op_queue->get_num_msm_rows()); + + const size_t num_rows = std::max({ point_table_rows.size(), msm_rows.size(), transcript_rows.size() }); + const auto log_num_rows = static_cast(numeric::get_msb64(num_rows)); + const size_t dyadic_num_rows = 1UL << (log_num_rows + (1UL << log_num_rows == num_rows ? 0 : 1)); + + // allocate polynomials; define lagrange and lookup read count polynomials for (auto& poly : get_all()) { - poly = Polynomial(num_rows_pow2); + poly = Polynomial(dyadic_num_rows); } lagrange_first[0] = 1; lagrange_second[1] = 1; lagrange_last[lagrange_last.size() - 1] = 1; - for (size_t i = 0; i < point_table_read_counts[0].size(); ++i) { - // Explanation of off-by-one offset + // Explanation of off-by-one offset: // When computing the WNAF slice for a point at point counter value `pc` and a round index `round`, the // row number that computes the slice can be derived. This row number is then mapped to the index of // `lookup_read_counts`. We do this mapping in `ecc_msm_relation`. We are off-by-one because we add an @@ -495,106 +492,109 @@ class ECCVMFlavor { lookup_read_counts_0[i + 1] = point_table_read_counts[0][i]; lookup_read_counts_1[i + 1] = point_table_read_counts[1][i]; } - run_loop_in_parallel(transcript_state.size(), [&](size_t start, size_t end) { + + // compute polynomials for transcript columns + run_loop_in_parallel(transcript_rows.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { - transcript_accumulator_empty[i] = transcript_state[i].accumulator_empty; - transcript_add[i] = transcript_state[i].q_add; - transcript_mul[i] = transcript_state[i].q_mul; - transcript_eq[i] = transcript_state[i].q_eq; - transcript_reset_accumulator[i] = transcript_state[i].q_reset_accumulator; - transcript_msm_transition[i] = transcript_state[i].msm_transition; - transcript_pc[i] = transcript_state[i].pc; - transcript_msm_count[i] = transcript_state[i].msm_count; - transcript_Px[i] = transcript_state[i].base_x; - transcript_Py[i] = transcript_state[i].base_y; - transcript_z1[i] = transcript_state[i].z1; - transcript_z2[i] = transcript_state[i].z2; - transcript_z1zero[i] = transcript_state[i].z1_zero; - transcript_z2zero[i] = transcript_state[i].z2_zero; - transcript_op[i] = transcript_state[i].opcode; - transcript_accumulator_x[i] = transcript_state[i].accumulator_x; - transcript_accumulator_y[i] = transcript_state[i].accumulator_y; - transcript_msm_x[i] = transcript_state[i].msm_output_x; - transcript_msm_y[i] = transcript_state[i].msm_output_y; - transcript_collision_check[i] = transcript_state[i].collision_check; + transcript_accumulator_empty[i] = transcript_rows[i].accumulator_empty; + transcript_add[i] = transcript_rows[i].q_add; + transcript_mul[i] = transcript_rows[i].q_mul; + transcript_eq[i] = transcript_rows[i].q_eq; + transcript_reset_accumulator[i] = transcript_rows[i].q_reset_accumulator; + transcript_msm_transition[i] = transcript_rows[i].msm_transition; + transcript_pc[i] = transcript_rows[i].pc; + transcript_msm_count[i] = transcript_rows[i].msm_count; + transcript_Px[i] = transcript_rows[i].base_x; + transcript_Py[i] = transcript_rows[i].base_y; + transcript_z1[i] = transcript_rows[i].z1; + transcript_z2[i] = transcript_rows[i].z2; + transcript_z1zero[i] = transcript_rows[i].z1_zero; + transcript_z2zero[i] = transcript_rows[i].z2_zero; + transcript_op[i] = transcript_rows[i].opcode; + transcript_accumulator_x[i] = transcript_rows[i].accumulator_x; + transcript_accumulator_y[i] = transcript_rows[i].accumulator_y; + transcript_msm_x[i] = transcript_rows[i].msm_output_x; + transcript_msm_y[i] = transcript_rows[i].msm_output_y; + transcript_collision_check[i] = transcript_rows[i].collision_check; } }); // TODO(@zac-williamson) if final opcode resets accumulator, all subsequent "is_accumulator_empty" row // values must be 1. Ideally we find a way to tweak this so that empty rows that do nothing have column // values that are all zero (issue #2217) - if (transcript_state[transcript_state.size() - 1].accumulator_empty == 1) { - for (size_t i = transcript_state.size(); i < num_rows_pow2; ++i) { + if (transcript_rows[transcript_rows.size() - 1].accumulator_empty) { + for (size_t i = transcript_rows.size(); i < dyadic_num_rows; ++i) { transcript_accumulator_empty[i] = 1; } } - run_loop_in_parallel(precompute_table_state.size(), [&](size_t start, size_t end) { + + // compute polynomials for point table columns + run_loop_in_parallel(point_table_rows.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { // first row is always an empty row (to accommodate shifted polynomials which must have 0 as 1st - // coefficient). All other rows in the precompute_table_state represent active wnaf gates (i.e. + // coefficient). All other rows in the point_table_rows represent active wnaf gates (i.e. // precompute_select = 1) precompute_select[i] = (i != 0) ? 1 : 0; - precompute_pc[i] = precompute_table_state[i].pc; - precompute_point_transition[i] = static_cast(precompute_table_state[i].point_transition); - precompute_round[i] = precompute_table_state[i].round; - precompute_scalar_sum[i] = precompute_table_state[i].scalar_sum; - - precompute_s1hi[i] = precompute_table_state[i].s1; - precompute_s1lo[i] = precompute_table_state[i].s2; - precompute_s2hi[i] = precompute_table_state[i].s3; - precompute_s2lo[i] = precompute_table_state[i].s4; - precompute_s3hi[i] = precompute_table_state[i].s5; - precompute_s3lo[i] = precompute_table_state[i].s6; - precompute_s4hi[i] = precompute_table_state[i].s7; - precompute_s4lo[i] = precompute_table_state[i].s8; + precompute_pc[i] = point_table_rows[i].pc; + precompute_point_transition[i] = static_cast(point_table_rows[i].point_transition); + precompute_round[i] = point_table_rows[i].round; + precompute_scalar_sum[i] = point_table_rows[i].scalar_sum; + precompute_s1hi[i] = point_table_rows[i].s1; + precompute_s1lo[i] = point_table_rows[i].s2; + precompute_s2hi[i] = point_table_rows[i].s3; + precompute_s2lo[i] = point_table_rows[i].s4; + precompute_s3hi[i] = point_table_rows[i].s5; + precompute_s3lo[i] = point_table_rows[i].s6; + precompute_s4hi[i] = point_table_rows[i].s7; + precompute_s4lo[i] = point_table_rows[i].s8; // If skew is active (i.e. we need to subtract a base point from the msm result), // write `7` into rows.precompute_skew. `7`, in binary representation, equals `-1` when converted // into WNAF form - precompute_skew[i] = precompute_table_state[i].skew ? 7 : 0; - - precompute_dx[i] = precompute_table_state[i].precompute_double.x; - precompute_dy[i] = precompute_table_state[i].precompute_double.y; - precompute_tx[i] = precompute_table_state[i].precompute_accumulator.x; - precompute_ty[i] = precompute_table_state[i].precompute_accumulator.y; + precompute_skew[i] = point_table_rows[i].skew ? 7 : 0; + precompute_dx[i] = point_table_rows[i].precompute_double.x; + precompute_dy[i] = point_table_rows[i].precompute_double.y; + precompute_tx[i] = point_table_rows[i].precompute_accumulator.x; + precompute_ty[i] = point_table_rows[i].precompute_accumulator.y; } }); - run_loop_in_parallel(msm_state.size(), [&](size_t start, size_t end) { + // compute polynomials for the msm columns + run_loop_in_parallel(msm_rows.size(), [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { - msm_transition[i] = static_cast(msm_state[i].msm_transition); - msm_add[i] = static_cast(msm_state[i].q_add); - msm_double[i] = static_cast(msm_state[i].q_double); - msm_skew[i] = static_cast(msm_state[i].q_skew); - msm_accumulator_x[i] = msm_state[i].accumulator_x; - msm_accumulator_y[i] = msm_state[i].accumulator_y; - msm_pc[i] = msm_state[i].pc; - msm_size_of_msm[i] = msm_state[i].msm_size; - msm_count[i] = msm_state[i].msm_count; - msm_round[i] = msm_state[i].msm_round; - msm_add1[i] = static_cast(msm_state[i].add_state[0].add); - msm_add2[i] = static_cast(msm_state[i].add_state[1].add); - msm_add3[i] = static_cast(msm_state[i].add_state[2].add); - msm_add4[i] = static_cast(msm_state[i].add_state[3].add); - msm_x1[i] = msm_state[i].add_state[0].point.x; - msm_y1[i] = msm_state[i].add_state[0].point.y; - msm_x2[i] = msm_state[i].add_state[1].point.x; - msm_y2[i] = msm_state[i].add_state[1].point.y; - msm_x3[i] = msm_state[i].add_state[2].point.x; - msm_y3[i] = msm_state[i].add_state[2].point.y; - msm_x4[i] = msm_state[i].add_state[3].point.x; - msm_y4[i] = msm_state[i].add_state[3].point.y; - msm_collision_x1[i] = msm_state[i].add_state[0].collision_inverse; - msm_collision_x2[i] = msm_state[i].add_state[1].collision_inverse; - msm_collision_x3[i] = msm_state[i].add_state[2].collision_inverse; - msm_collision_x4[i] = msm_state[i].add_state[3].collision_inverse; - msm_lambda1[i] = msm_state[i].add_state[0].lambda; - msm_lambda2[i] = msm_state[i].add_state[1].lambda; - msm_lambda3[i] = msm_state[i].add_state[2].lambda; - msm_lambda4[i] = msm_state[i].add_state[3].lambda; - msm_slice1[i] = msm_state[i].add_state[0].slice; - msm_slice2[i] = msm_state[i].add_state[1].slice; - msm_slice3[i] = msm_state[i].add_state[2].slice; - msm_slice4[i] = msm_state[i].add_state[3].slice; + msm_transition[i] = static_cast(msm_rows[i].msm_transition); + msm_add[i] = static_cast(msm_rows[i].q_add); + msm_double[i] = static_cast(msm_rows[i].q_double); + msm_skew[i] = static_cast(msm_rows[i].q_skew); + msm_accumulator_x[i] = msm_rows[i].accumulator_x; + msm_accumulator_y[i] = msm_rows[i].accumulator_y; + msm_pc[i] = msm_rows[i].pc; + msm_size_of_msm[i] = msm_rows[i].msm_size; + msm_count[i] = msm_rows[i].msm_count; + msm_round[i] = msm_rows[i].msm_round; + msm_add1[i] = static_cast(msm_rows[i].add_state[0].add); + msm_add2[i] = static_cast(msm_rows[i].add_state[1].add); + msm_add3[i] = static_cast(msm_rows[i].add_state[2].add); + msm_add4[i] = static_cast(msm_rows[i].add_state[3].add); + msm_x1[i] = msm_rows[i].add_state[0].point.x; + msm_y1[i] = msm_rows[i].add_state[0].point.y; + msm_x2[i] = msm_rows[i].add_state[1].point.x; + msm_y2[i] = msm_rows[i].add_state[1].point.y; + msm_x3[i] = msm_rows[i].add_state[2].point.x; + msm_y3[i] = msm_rows[i].add_state[2].point.y; + msm_x4[i] = msm_rows[i].add_state[3].point.x; + msm_y4[i] = msm_rows[i].add_state[3].point.y; + msm_collision_x1[i] = msm_rows[i].add_state[0].collision_inverse; + msm_collision_x2[i] = msm_rows[i].add_state[1].collision_inverse; + msm_collision_x3[i] = msm_rows[i].add_state[2].collision_inverse; + msm_collision_x4[i] = msm_rows[i].add_state[3].collision_inverse; + msm_lambda1[i] = msm_rows[i].add_state[0].lambda; + msm_lambda2[i] = msm_rows[i].add_state[1].lambda; + msm_lambda3[i] = msm_rows[i].add_state[2].lambda; + msm_lambda4[i] = msm_rows[i].add_state[3].lambda; + msm_slice1[i] = msm_rows[i].add_state[0].slice; + msm_slice2[i] = msm_rows[i].add_state[1].slice; + msm_slice3[i] = msm_rows[i].add_state[2].slice; + msm_slice4[i] = msm_rows[i].add_state[3].slice; } }); this->set_shifted(); diff --git a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp index 5572bab54ee..69f4871eb91 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/msm_builder.hpp @@ -13,13 +13,15 @@ class ECCVMMSMMBuilder { using FF = curve::Grumpkin::ScalarField; using Element = typename CycleGroup::element; using AffineElement = typename CycleGroup::affine_element; + using MSM = bb::eccvm::MSM; static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW; - static constexpr size_t NUM_SCALAR_BITS = bb::eccvm::NUM_SCALAR_BITS; - static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS; + static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR; - struct alignas(64) MSMState { + struct alignas(64) MSMRow { + // counter over all half-length scalar muls used to compute the required MSMs uint32_t pc = 0; + // the number of points that will be scaled and summed uint32_t msm_size = 0; uint32_t msm_count = 0; uint32_t msm_round = 0; @@ -43,138 +45,138 @@ class ECCVMMSMMBuilder { FF accumulator_y = 0; }; - struct alignas(64) MSMRowTranscript { - std::array lambda_numerator; - std::array lambda_denominator; - Element accumulator_in; - Element accumulator_out; - }; - - struct alignas(64) AdditionTrace { - Element p1; - Element p2; - Element p3; - bool predicate; - bool is_double; - }; - /** * @brief Computes the row values for the Straus MSM columns of the ECCVM. * * For a detailed description of the Straus algorithm and its relation to the ECCVM, please see * https://hackmd.io/@aztec-network/rJ5xhuCsn * - * @param msms - * @param point_table_read_counts - * @param total_number_of_muls - * @return std::vector + * @param msms A vector of vectors of ScalarMuls. + * @param point_table_read_counts Table of read counts to be populated. + * @param total_number_of_muls A mul op in the OpQueue adds up to two muls, one for each nonzero z_i (i=1,2). + * @param num_msm_rows + * @return std::vector */ - static std::vector compute_msm_state(const std::vector>& msms, - std::array, 2>& point_table_read_counts, - const uint32_t total_number_of_muls, - const size_t num_msm_rows) + static std::tuple, std::array, 2>> compute_rows( + const std::vector& msms, const uint32_t total_number_of_muls, const size_t num_msm_rows) { - // N.B. the following comments refer to a "point lookup table" frequently. - // To perform a scalar multiplicaiton of a point [P] by a scalar x, we compute multiples of [P] and store in a - // table: specifically: -15[P], -13[P], ..., -3[P], -[P], [P], 3[P], ..., 15[P] when we define our point lookup - // table, we have 2 write columns and 4 read columns when we perform a read on a given row, we need to increment - // the read count on the respective write column by 1 we can define the following struture: 1st write column = - // positive 2nd write column = negative the row number is a function of pc and slice value row = pc_delta * - // rows_per_point_table + some function of the slice value pc_delta = total_number_of_muls - pc - // std::vector point_table_read_counts; - const size_t table_rows = static_cast(total_number_of_muls) * 8; - point_table_read_counts[0].reserve(table_rows); - point_table_read_counts[1].reserve(table_rows); - for (size_t i = 0; i < table_rows; ++i) { + // To perform a scalar multiplication of a point P by a scalar x, we precompute a table of points + // -15P, -13P, ..., -3P, -P, P, 3P, ..., 15P + // When we perform a scalar multiplication, we decompose x into base-16 wNAF digits then look these precomputed + // values up with digit-by-digit. We record read counts in a table with the following structure: + // 1st write column = positive wNAF digits + // 2nd write column = negative wNAF digits + // the row number is a function of pc and wnaf digit: + // point_idx = total_number_of_muls - pc + // row = point_idx * rows_per_point_table + (some function of the slice value) + // + // Illustration: + // Block Structure Table structure: + // | 0 | 1 | | Block_{0} | <-- pc = total_number_of_muls + // | - | - | | Block_{1} | <-- pc = total_number_of_muls-(num muls in msm 0) + // 1 | # | # | -1 | ... | ... + // 3 | # | # | -3 | Block_{total_number_of_muls-1} | <-- pc = num muls in last msm + // 5 | # | # | -5 + // 7 | # | # | -7 + // 9 | # | # | -9 + // 11 | # | # | -11 + // 13 | # | # | -13 + // 15 | # | # | -15 + + const size_t num_rows_in_read_counts_table = + static_cast(total_number_of_muls) * (eccvm::POINT_TABLE_SIZE >> 1); + std::array, 2> point_table_read_counts; + point_table_read_counts[0].reserve(num_rows_in_read_counts_table); + point_table_read_counts[1].reserve(num_rows_in_read_counts_table); + for (size_t i = 0; i < num_rows_in_read_counts_table; ++i) { point_table_read_counts[0].emplace_back(0); point_table_read_counts[1].emplace_back(0); } - const auto update_read_counts = [&](const size_t pc, const int slice) { - // When we compute our wnaf/point tables, we start with the point with the largest pc value. - // i.e. if we are reading a slice for point with a point counter value `pc`, - // its position in the wnaf/point table (relative to other points) will be `total_number_of_muls - pc` - const size_t pc_delta = total_number_of_muls - pc; - const size_t pc_offset = pc_delta * 8; - bool slice_negative = slice < 0; - const int slice_row = (slice + 15) / 2; - - const size_t column_index = slice_negative ? 1 : 0; + const auto update_read_count = [&point_table_read_counts](const size_t point_idx, const int slice) { /** - * When computing `point_table_read_counts`, we need the *table index* that a given point belongs to. - * the slice value is in *compressed* windowed-non-adjacent-form format: - * A non-compressed WNAF slice is in the range: `-15, -13, ..., 15` - * In compressed form, tney become `0, ..., 15` + * The wNAF digits for base 16 lie in the range -15, -13, ..., 13, 15. * The *point table* format is the following: - * (for positive point table) T[0] = P, T[1] = PT, ..., T[7] = 15P + * (for positive point table) T[0] = P, T[1] = 3P, ..., T[7] = 15P * (for negative point table) T[0] = -P, T[1] = -3P, ..., T[15] = -15P * i.e. if the slice value is negative, we can use the compressed WNAF directly as the table index - * if the slice value is positive, we must take `15 - compressedWNAF` to get the table index + * if the slice value is positive, we must take 15 - (compressed wNAF) to get the table index */ - if (slice_negative) { - point_table_read_counts[column_index][pc_offset + static_cast(slice_row)]++; + const size_t row_index_offset = point_idx * 8; + const bool digit_is_negative = slice < 0; + const auto relative_row_idx = static_cast((slice + 15) / 2); + const size_t column_index = digit_is_negative ? 1 : 0; + + if (digit_is_negative) { + point_table_read_counts[column_index][row_index_offset + relative_row_idx]++; } else { - point_table_read_counts[column_index][pc_offset + 15 - static_cast(slice_row)]++; + point_table_read_counts[column_index][row_index_offset + 15 - relative_row_idx]++; } }; // compute which row index each multiscalar multiplication will start at. - // also compute the program counter index that each multiscalar multiplication will start at. - // we use this information to populate the MSM row data across multiple threads - std::vector msm_row_indices; - std::vector pc_indices; - msm_row_indices.reserve(msms.size() + 1); - pc_indices.reserve(msms.size() + 1); - - msm_row_indices.push_back(1); - pc_indices.push_back(total_number_of_muls); + std::vector msm_row_counts; + msm_row_counts.reserve(msms.size() + 1); + msm_row_counts.push_back(1); + // compute the program counter (i.e. the index among all single scalar muls) that each multiscalar + // multiplication will start at. + std::vector pc_values; + pc_values.reserve(msms.size() + 1); + pc_values.push_back(total_number_of_muls); for (const auto& msm : msms) { - const size_t rows = ECCOpQueue::get_msm_row_count_for_single_msm(msm.size()); - msm_row_indices.push_back(msm_row_indices.back() + rows); - pc_indices.push_back(pc_indices.back() - msm.size()); + const size_t num_rows_required = ECCOpQueue::num_eccvm_msm_rows(msm.size()); + msm_row_counts.push_back(msm_row_counts.back() + num_rows_required); + pc_values.push_back(pc_values.back() - msm.size()); } + ASSERT(pc_values.back() == 0); - static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS; - std::vector msm_state(num_msm_rows); - // start with empty row (shiftable polynomials must have 0 as first coefficient) - msm_state[0] = (MSMState{}); + // compute the MSM rows + std::vector msm_rows(num_msm_rows); + // start with empty row (shiftable polynomials must have 0 as first coefficient) + msm_rows[0] = (MSMRow{}); // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup // tables are called. - // Note: this part is single-threaded. THe amount of compute is low, however, so this is likely not a big + // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big // concern. - for (size_t i = 0; i < msms.size(); ++i) { - - for (size_t j = 0; j < num_rounds; ++j) { - uint32_t pc = static_cast(pc_indices[i]); - const auto& msm = msms[i]; + for (size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) { + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + auto pc = static_cast(pc_values[msm_idx]); + const auto& msm = msms[msm_idx]; const size_t msm_size = msm.size(); - const size_t rows_per_round = - (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - bool add = points_per_row > m; + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); + + for (size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) { + const size_t num_points_in_row = (relative_row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + const size_t offset = relative_row_idx * ADDITIONS_PER_ROW; + for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; ++relative_point_idx) { + const size_t point_idx = offset + relative_point_idx; + const bool add = num_points_in_row > relative_point_idx; if (add) { - int slice = add ? msm[idx + m].wnaf_slices[j] : 0; - update_read_counts(pc - idx - m, slice); + int slice = msm[point_idx].wnaf_digits[digit_idx]; + // pc starts at total_number_of_muls and decreses non-uniformly to 0 + update_read_count((total_number_of_muls - pc) + point_idx, slice); } } } - if (j == num_rounds - 1) { - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; - for (size_t m = 0; m < 4; ++m) { - bool add = points_per_row > m; - + if (digit_idx == NUM_WNAF_DIGITS_PER_SCALAR - 1) { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + const size_t offset = row_idx * ADDITIONS_PER_ROW; + for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; + ++relative_point_idx) { + bool add = num_points_in_row > relative_point_idx; + const size_t point_idx = offset + relative_point_idx; if (add) { - update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15); + // pc starts at total_number_of_muls and decreses non-uniformly to 0 + int slice = msm[point_idx].wnaf_skew ? -1 : -15; + update_read_count((total_number_of_muls - pc) + point_idx, slice); } } } @@ -184,80 +186,84 @@ class ECCVMMSMMBuilder { // The execution trace data for the MSM columns requires knowledge of intermediate values from *affine* point // addition. The naive solution to compute this data requires 2 field inversions per in-circuit group addition - // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps. Step 1: - // compute the execution trace group operations in *projective* coordinates Step 2: use batch inversion trick to - // convert all point traces into affine coordinates Step 3: populate the full execution trace, including the - // intermediate values from affine group operations This section sets up the data structures we need to store - // all intermediate ECC operations in projective form + // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps. + // Step 1: compute the execution trace group operations in *projective* coordinates + // Step 2: use batch inversion trick to convert all points into affine coordinates + // Step 3: populate the full execution trace, including the intermediate values from affine group operations + // This section sets up the data structures we need to store all intermediate ECC operations in projective form const size_t num_point_adds_and_doubles = (num_msm_rows - 2) * 4; const size_t num_accumulators = num_msm_rows - 1; - const size_t num_points_in_trace = (num_point_adds_and_doubles * 3) + num_accumulators; + // In what fallows, either p1 + p2 = p3, or p1.dbl() = p3 // We create 1 vector to store the entire point trace. We split into multiple containers using std::span // (we want 1 vector object to more efficiently batch normalize points) - std::vector point_trace(num_points_in_trace); - // the point traces record group operations. Either p1 + p2 = p3, or p1.dbl() = p3 - std::span p1_trace(&point_trace[0], num_point_adds_and_doubles); - std::span p2_trace(&point_trace[num_point_adds_and_doubles], num_point_adds_and_doubles); - std::span p3_trace(&point_trace[num_point_adds_and_doubles * 2], num_point_adds_and_doubles); + static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3; + const size_t num_points_to_normalize = + (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators; + std::vector points_to_normalize(num_points_to_normalize); + std::span p1_trace(&points_to_normalize[0], num_point_adds_and_doubles); + std::span p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles); + std::span p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles); // operation_trace records whether an entry in the p1/p2/p3 trace represents a point addition or doubling std::vector operation_trace(num_point_adds_and_doubles); // accumulator_trace tracks the value of the ECCVM accumulator for each row - std::span accumulator_trace(&point_trace[num_point_adds_and_doubles * 3], num_accumulators); + std::span accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators); // we start the accumulator at the point at infinity accumulator_trace[0] = (CycleGroup::affine_point_at_infinity); // TODO(https://github.com/AztecProtocol/barretenberg/issues/973): Reinstate multitreading? - // populate point trace data, and the components of the MSM execution trace that do not relate to affine point + // populate point trace, and the components of the MSM execution trace that do not relate to affine point // operations - for (size_t i = 0; i < msms.size(); i++) { + for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) { Element accumulator = CycleGroup::affine_point_at_infinity; - const auto& msm = msms[i]; - size_t msm_row_index = msm_row_indices[i]; + const auto& msm = msms[msm_idx]; + size_t msm_row_index = msm_row_counts[msm_idx]; const size_t msm_size = msm.size(); - const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); - size_t trace_index = (msm_row_indices[i] - 1) * 4; - - for (size_t j = 0; j < num_rounds; ++j) { - const uint32_t pc = static_cast(pc_indices[i]); - - for (size_t k = 0; k < rows_per_round; ++k) { - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - auto& row = msm_state[msm_row_index]; - const size_t idx = k * ADDITIONS_PER_ROW; - row.msm_transition = (j == 0) && (k == 0); - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0; + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); + size_t trace_index = (msm_row_counts[msm_idx] - 1) * 4; + + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + const auto pc = static_cast(pc_values[msm_idx]); + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? (msm_size % ADDITIONS_PER_ROW) + : ADDITIONS_PER_ROW; + auto& row = msm_rows[msm_row_index]; + const size_t offset = row_idx * ADDITIONS_PER_ROW; + row.msm_transition = (digit_idx == 0) && (row_idx == 0); + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + + auto& add_state = row.add_state[point_idx]; + add_state.add = num_points_in_row > point_idx; + int slice = add_state.add ? msm[offset + point_idx].wnaf_digits[digit_idx] : 0; // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row. - // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in - // the MSM columns into the MSM accumulator `add_state.slice` = A 4-bit WNAF slice of - // the scalar multiplier associated with the point we are adding (the specific slice - // chosen depends on the value of msm_round) (WNAF = windowed-non-adjacent-form. Value - // range is `-15, -13, + // if `row.add_state[point_idx].add = 1`, this indicates that we want to add the + // `point_idx`'th point in the MSM columns into the MSM accumulator `add_state.slice` = A + // 4-bit WNAF slice of the scalar multiplier associated with the point we are adding (the + // specific slice chosen depends on the value of msm_round) (WNAF = + // windowed-non-adjacent-form. Value range is `-15, -13, // ..., 15`) If `add_state.add = 1`, we want `add_state.slice` to be the *compressed* // form of the WNAF slice value. (compressed = no gaps in the value range. i.e. -15, // -13, ..., 15 maps to 0, ... , 15) add_state.slice = add_state.add ? (slice + 15) / 2 : 0; - add_state.point = add_state.add - ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; + add_state.point = + add_state.add + ? msm[offset + point_idx].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; // predicate logic: // add_predicate should normally equal add_state.add - // However! if j == 0 AND k == 0 AND m == 0 this implies we are examing the 1st point - // addition of a new MSM In this case, we do NOT add the 1st point into the accumulator, - // instead we SET the accumulator to equal the 1st point. add_predicate is used to - // determine whether we add the output of a point addition into the accumulator, - // therefore if j == 0 AND k == 0 AND m == 0, add_predicate = 0 even if add_state.add = - // true - bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); + // However! if digit_idx == 0 AND row_idx == 0 AND point_idx == 0 this implies we are + // examing the 1st point addition of a new MSM. In this case, we do NOT add the 1st point + // into the accumulator, instead we SET the accumulator to equal the 1st point. + // add_predicate is used to determine whether we add the output of a point addition into the + // accumulator, therefore if digit_idx == 0 AND row_idx == 0 AND point_idx == 0, + // add_predicate = 0 even if add_state.add = true + bool add_predicate = (point_idx == 0 ? (digit_idx != 0 || row_idx != 0) : add_state.add); - Element p1 = (m == 0) ? Element(add_state.point) : accumulator; - Element p2 = (m == 0) ? accumulator : Element(add_state.point); + Element p1 = (point_idx == 0) ? Element(add_state.point) : accumulator; + Element p2 = (point_idx == 0) ? accumulator : Element(add_state.point); accumulator = add_predicate ? (accumulator + add_state.point) : Element(p1); p1_trace[trace_index] = p1; @@ -270,25 +276,24 @@ class ECCVMMSMMBuilder { row.q_add = true; row.q_double = false; row.q_skew = false; - row.msm_round = static_cast(j); + row.msm_round = static_cast(digit_idx); row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); + row.msm_count = static_cast(offset); row.pc = pc; msm_row_index++; } // doubling - if (j < num_rounds - 1) { - auto& row = msm_state[msm_row_index]; + if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + auto& row = msm_rows[msm_row_index]; row.msm_transition = false; - row.msm_round = static_cast(j + 1); + row.msm_round = static_cast(digit_idx + 1); row.msm_size = static_cast(msm_size); row.msm_count = static_cast(0); row.q_add = false; row.q_double = true; row.q_skew = false; - for (size_t m = 0; m < 4; ++m) { - - auto& add_state = row.add_state[m]; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; add_state.add = false; add_state.slice = 0; add_state.point = { 0, 0 }; @@ -304,25 +309,25 @@ class ECCVMMSMMBuilder { accumulator_trace[msm_row_index] = accumulator; msm_row_index++; } else { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + auto& row = msm_rows[msm_row_index]; - const size_t points_per_row = - (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW; - const size_t idx = k * ADDITIONS_PER_ROW; + const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size + ? msm_size % ADDITIONS_PER_ROW + : ADDITIONS_PER_ROW; + const size_t offset = row_idx * ADDITIONS_PER_ROW; row.msm_transition = false; - Element acc_expected = accumulator; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; - add_state.add = points_per_row > m; - add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0; - - add_state.point = add_state.add - ? msm[idx + m].precomputed_table[static_cast(add_state.slice)] - : AffineElement{ 0, 0 }; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + add_state.add = num_points_in_row > point_idx; + add_state.slice = add_state.add ? msm[offset + point_idx].wnaf_skew ? 7 : 0 : 0; + + add_state.point = + add_state.add + ? msm[offset + point_idx].precomputed_table[static_cast(add_state.slice)] + : AffineElement{ 0, 0 }; + bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false; auto p1 = accumulator; accumulator = add_predicate ? accumulator + add_state.point : accumulator; p1_trace[trace_index] = p1; @@ -334,9 +339,9 @@ class ECCVMMSMMBuilder { row.q_add = false; row.q_double = false; row.q_skew = true; - row.msm_round = static_cast(j + 1); + row.msm_round = static_cast(digit_idx + 1); row.msm_size = static_cast(msm_size); - row.msm_count = static_cast(idx); + row.msm_count = static_cast(offset); row.pc = pc; accumulator_trace[msm_row_index] = accumulator; msm_row_index++; @@ -346,18 +351,18 @@ class ECCVMMSMMBuilder { } // Normalize the points in the point trace - run_loop_in_parallel(point_trace.size(), [&](size_t start, size_t end) { - Element::batch_normalize(&point_trace[start], end - start); + run_loop_in_parallel(points_to_normalize.size(), [&](size_t start, size_t end) { + Element::batch_normalize(&points_to_normalize[start], end - start); }); // inverse_trace is used to compute the value of the `collision_inverse` column in the ECCVM. std::vector inverse_trace(num_point_adds_and_doubles); run_loop_in_parallel(num_point_adds_and_doubles, [&](size_t start, size_t end) { - for (size_t i = start; i < end; ++i) { - if (operation_trace[i]) { - inverse_trace[i] = (p1_trace[i].y + p1_trace[i].y); + for (size_t operation_idx = start; operation_idx < end; ++operation_idx) { + if (operation_trace[operation_idx]) { + inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y); } else { - inverse_trace[i] = (p2_trace[i].x - p1_trace[i].x); + inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x); } } FF::batch_invert(&inverse_trace[start], end - start); @@ -366,28 +371,29 @@ class ECCVMMSMMBuilder { // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse, // row.add_state[0...3].lambda - for (size_t i = 0; i < msms.size(); i++) { - const auto& msm = msms[i]; - size_t trace_index = ((msm_row_indices[i] - 1) * ADDITIONS_PER_ROW); - size_t msm_row_index = msm_row_indices[i]; + for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) { + const auto& msm = msms[msm_idx]; + size_t trace_index = ((msm_row_counts[msm_idx] - 1) * ADDITIONS_PER_ROW); + size_t msm_row_index = msm_row_counts[msm_idx]; // 1st MSM row will have accumulator equal to the previous MSM output // (or point at infinity for 1st MSM) - size_t accumulator_index = msm_row_indices[i] - 1; + size_t accumulator_index = msm_row_counts[msm_idx] - 1; const size_t msm_size = msm.size(); - const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0); + const size_t num_rows_per_digit = + (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0); - for (size_t j = 0; j < num_rounds; ++j) { - for (size_t k = 0; k < rows_per_round; ++k) { - auto& row = msm_state[msm_row_index]; + for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) { + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + auto& row = msm_rows[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; row.accumulator_x = acc_x; row.accumulator_y = acc_y; - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; - bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add); + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + bool add_predicate = (point_idx == 0 ? (digit_idx != 0 || row_idx != 0) : add_state.add); const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; @@ -400,16 +406,15 @@ class ECCVMMSMMBuilder { msm_row_index++; } - if (j < num_rounds - 1) { - MSMState& row = msm_state[msm_row_index]; + if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) { + MSMRow& row = msm_rows[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; row.accumulator_x = acc_x; row.accumulator_y = acc_y; - - for (size_t m = 0; m < 4; ++m) { - auto& add_state = row.add_state[m]; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; add_state.collision_inverse = 0; const FF& dx = p1_trace[trace_index].x; const FF& inverse = inverse_trace[trace_index]; @@ -419,20 +424,17 @@ class ECCVMMSMMBuilder { accumulator_index++; msm_row_index++; } else { - for (size_t k = 0; k < rows_per_round; ++k) { - MSMState& row = msm_state[msm_row_index]; + for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) { + MSMRow& row = msm_rows[msm_row_index]; const Element& normalized_accumulator = accumulator_trace[accumulator_index]; - - const size_t idx = k * ADDITIONS_PER_ROW; - + const size_t offset = row_idx * ADDITIONS_PER_ROW; const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x; const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y; row.accumulator_x = acc_x; row.accumulator_y = acc_y; - - for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) { - auto& add_state = row.add_state[m]; - bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false; + for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) { + auto& add_state = row.add_state[point_idx]; + bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false; const auto& inverse = inverse_trace[trace_index]; const auto& p1 = p1_trace[trace_index]; @@ -452,8 +454,8 @@ class ECCVMMSMMBuilder { // we always require 1 extra row at the end of the trace, because the accumulator x/y coordinates for row `i` // are present at row `i+1` Element final_accumulator(accumulator_trace.back()); - MSMState& final_row = msm_state.back(); - final_row.pc = static_cast(pc_indices.back()); + MSMRow& final_row = msm_rows.back(); + final_row.pc = static_cast(pc_values.back()); final_row.msm_transition = true; final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x; final_row.accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y; @@ -462,12 +464,12 @@ class ECCVMMSMMBuilder { final_row.q_add = false; final_row.q_double = false; final_row.q_skew = false; - final_row.add_state = { typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, - typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; + final_row.add_state = { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 }, + typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } }; - return msm_state; + return { msm_rows, point_table_read_counts }; } }; } // namespace bb diff --git a/barretenberg/cpp/src/barretenberg/eccvm/precomputed_tables_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/precomputed_tables_builder.hpp index ed77be8f6a6..c98e1d56b8b 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/precomputed_tables_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/precomputed_tables_builder.hpp @@ -4,18 +4,18 @@ namespace bb { -class ECCVMPrecomputedTablesBuilder { +class ECCVMPointTablePrecomputationBuilder { public: using CycleGroup = bb::g1; using FF = grumpkin::fr; using Element = typename CycleGroup::element; using AffineElement = typename CycleGroup::affine_element; - static constexpr size_t NUM_WNAF_SLICES = bb::eccvm::NUM_WNAF_SLICES; - static constexpr size_t WNAF_SLICES_PER_ROW = bb::eccvm::WNAF_SLICES_PER_ROW; - static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS; + static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR; + static constexpr size_t WNAF_DIGITS_PER_ROW = bb::eccvm::WNAF_DIGITS_PER_ROW; + static constexpr size_t NUM_WNAF_DIGIT_BITS = bb::eccvm::NUM_WNAF_DIGIT_BITS; - struct PrecomputeState { + struct PointTablePrecoputationRow { int s1 = 0; int s2 = 0; int s3 = 0; @@ -33,31 +33,31 @@ class ECCVMPrecomputedTablesBuilder { AffineElement precompute_double{ 0, 0 }; }; - static std::vector compute_precompute_state( + static std::vector compute_rows( const std::vector>& ecc_muls) { - static constexpr size_t num_rows_per_scalar = NUM_WNAF_SLICES / WNAF_SLICES_PER_ROW; + static constexpr size_t num_rows_per_scalar = NUM_WNAF_DIGITS_PER_SCALAR / WNAF_DIGITS_PER_ROW; const size_t num_precompute_rows = num_rows_per_scalar * ecc_muls.size() + 1; - std::vector precompute_state(num_precompute_rows); + std::vector precompute_state(num_precompute_rows); // start with empty row (shiftable polynomials must have 0 as first coefficient) - precompute_state[0] = PrecomputeState{}; + precompute_state[0] = PointTablePrecoputationRow{}; // current impl doesn't work if not 4 - static_assert(WNAF_SLICES_PER_ROW == 4); + static_assert(WNAF_DIGITS_PER_ROW == 4); run_loop_in_parallel(ecc_muls.size(), [&](size_t start, size_t end) { for (size_t j = start; j < end; j++) { const auto& entry = ecc_muls[j]; - const auto& slices = entry.wnaf_slices; + const auto& slices = entry.wnaf_digits; uint256_t scalar_sum = 0; for (size_t i = 0; i < num_rows_per_scalar; ++i) { - PrecomputeState row; - const int slice0 = slices[i * WNAF_SLICES_PER_ROW]; - const int slice1 = slices[i * WNAF_SLICES_PER_ROW + 1]; - const int slice2 = slices[i * WNAF_SLICES_PER_ROW + 2]; - const int slice3 = slices[i * WNAF_SLICES_PER_ROW + 3]; + PointTablePrecoputationRow row; + const int slice0 = slices[i * WNAF_DIGITS_PER_ROW]; + const int slice1 = slices[i * WNAF_DIGITS_PER_ROW + 1]; + const int slice2 = slices[i * WNAF_DIGITS_PER_ROW + 2]; + const int slice3 = slices[i * WNAF_DIGITS_PER_ROW + 3]; const int slice0base2 = (slice0 + 15) / 2; const int slice1base2 = (slice1 + 15) / 2; @@ -85,7 +85,7 @@ class ECCVMPrecomputedTablesBuilder { bool chunk_negative = row_chunk < 0; - scalar_sum = scalar_sum << (WNAF_SLICE_BITS * WNAF_SLICES_PER_ROW); + scalar_sum = scalar_sum << (NUM_WNAF_DIGIT_BITS * WNAF_DIGITS_PER_ROW); if (chunk_negative) { scalar_sum -= static_cast(-row_chunk); } else { diff --git a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp index 106d83b5d4b..b3d93d3d1f8 100644 --- a/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp +++ b/barretenberg/cpp/src/barretenberg/eccvm/transcript_builder.hpp @@ -11,7 +11,7 @@ class ECCVMTranscriptBuilder { using Element = typename CycleGroup::element; using AffineElement = typename CycleGroup::affine_element; - struct TranscriptState { + struct TranscriptRow { bool accumulator_empty = false; bool q_add = false; bool q_mul = false; @@ -57,12 +57,12 @@ class ECCVMTranscriptBuilder { return res; } }; - static std::vector compute_transcript_state( - const std::vector>& vm_operations, const uint32_t total_number_of_muls) + static std::vector compute_rows(const std::vector>& vm_operations, + const uint32_t total_number_of_muls) { const size_t num_transcript_entries = vm_operations.size() + 2; - std::vector transcript_state(num_transcript_entries); + std::vector transcript_state(num_transcript_entries); std::vector inverse_trace(num_transcript_entries - 2); VMState state{ .pc = total_number_of_muls, @@ -73,9 +73,9 @@ class ECCVMTranscriptBuilder { }; VMState updated_state; // add an empty row. 1st row all zeroes because of our shiftable polynomials - transcript_state[0] = (TranscriptState{}); + transcript_state[0] = (TranscriptRow{}); for (size_t i = 0; i < vm_operations.size(); ++i) { - TranscriptState& row = transcript_state[i + 1]; + TranscriptRow& row = transcript_state[i + 1]; const bb::eccvm::VMOperation& entry = vm_operations[i]; const bool is_mul = entry.mul; @@ -180,7 +180,7 @@ class ECCVMTranscriptBuilder { for (size_t i = 0; i < inverse_trace.size(); ++i) { transcript_state[i + 1].collision_check = inverse_trace[i]; } - TranscriptState& final_row = transcript_state.back(); + TranscriptRow& final_row = transcript_state.back(); final_row.pc = updated_state.pc; final_row.accumulator_x = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.x; final_row.accumulator_y = (updated_state.accumulator.is_point_at_infinity()) ? 0 : updated_state.accumulator.y; diff --git a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index 4ef2ef12ef8..c3f04728cd3 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -261,18 +261,19 @@ class ECCOpQueue { } /** - * @brief Get the number of rows in the 'msm' column section o the ECCVM, associated with a single multiscalar mul + * @brief Get the number of rows in the 'msm' column section of the ECCVM associated with a single multiscalar + * multiplication. * - * @param msm_count + * @param msm_size * @return uint32_t */ - static uint32_t get_msm_row_count_for_single_msm(const size_t msm_count) + static uint32_t num_eccvm_msm_rows(const size_t msm_size) { - const size_t rows_per_round = - (msm_count / eccvm::ADDITIONS_PER_ROW) + (msm_count % eccvm::ADDITIONS_PER_ROW != 0 ? 1 : 0); - constexpr size_t num_rounds = eccvm::NUM_SCALAR_BITS / eccvm::WNAF_SLICE_BITS; - const size_t num_rows_for_all_rounds = (num_rounds + 1) * rows_per_round; // + 1 round for skew - const size_t num_double_rounds = num_rounds - 1; + const size_t rows_per_wnaf_digit = + (msm_size / eccvm::ADDITIONS_PER_ROW) + ((msm_size % eccvm::ADDITIONS_PER_ROW != 0) ? 1 : 0); + const size_t num_rows_for_all_rounds = + (eccvm::NUM_WNAF_DIGITS_PER_SCALAR + 1) * rows_per_wnaf_digit; // + 1 round for skew + const size_t num_double_rounds = eccvm::NUM_WNAF_DIGITS_PER_SCALAR - 1; const size_t num_rows_for_msm = num_rows_for_all_rounds + num_double_rounds; return static_cast(num_rows_for_msm); @@ -287,7 +288,7 @@ class ECCOpQueue { { size_t msm_rows = num_msm_rows + 2; if (cached_active_msm_count > 0) { - msm_rows += get_msm_row_count_for_single_msm(cached_active_msm_count); + msm_rows += num_eccvm_msm_rows(cached_active_msm_count); } return msm_rows; } @@ -305,7 +306,7 @@ class ECCOpQueue { // add 1 row to start of precompute table section size_t precompute_rows = num_precompute_table_rows + 1; if (cached_active_msm_count > 0) { - msm_rows += get_msm_row_count_for_single_msm(cached_active_msm_count); + msm_rows += num_eccvm_msm_rows(cached_active_msm_count); precompute_rows += get_precompute_table_row_count_for_single_msm(cached_active_msm_count); } @@ -323,7 +324,7 @@ class ECCOpQueue { accumulator = accumulator + to_add; // Construct and store the operation in the ultra op format - auto ultra_op = construct_and_populate_ultra_ops(ADD_ACCUM, to_add); + UltraOp ultra_op = construct_and_populate_ultra_ops(ADD_ACCUM, to_add); // Store the raw operation raw_ops.emplace_back(ECCVMOperation{ @@ -353,7 +354,7 @@ class ECCOpQueue { accumulator = accumulator + to_mul * scalar; // Construct and store the operation in the ultra op format - auto ultra_op = construct_and_populate_ultra_ops(MUL_ACCUM, to_mul, scalar); + UltraOp ultra_op = construct_and_populate_ultra_ops(MUL_ACCUM, to_mul, scalar); // Store the raw operation raw_ops.emplace_back(ECCVMOperation{ @@ -383,7 +384,7 @@ class ECCOpQueue { accumulator.self_set_infinity(); // Construct and store the operation in the ultra op format - auto ultra_op = construct_and_populate_ultra_ops(EQUALITY, expected); + UltraOp ultra_op = construct_and_populate_ultra_ops(EQUALITY, expected); // Store raw operation raw_ops.emplace_back(ECCVMOperation{ @@ -404,7 +405,9 @@ class ECCOpQueue { private: /** - * @brief when inserting operations, update the number of multiplications in the latest scalar mul + * @brief Update cached_active_msm_count or update other row counts and reset cached_active_msm_count. + * @details To the OpQueue, an MSM is a sequence of successive mul opcodes (note that mul might better be called + * mul_add--its effect on the accumulator is += scalar * point). * * @param op */ @@ -418,7 +421,7 @@ class ECCOpQueue { cached_active_msm_count++; } } else if (cached_active_msm_count != 0) { - num_msm_rows += get_msm_row_count_for_single_msm(cached_active_msm_count); + num_msm_rows += num_eccvm_msm_rows(cached_active_msm_count); num_precompute_table_rows += get_precompute_table_row_count_for_single_msm(cached_active_msm_count); cached_num_muls += cached_active_msm_count; cached_active_msm_count = 0; @@ -433,7 +436,8 @@ class ECCOpQueue { */ static uint32_t get_precompute_table_row_count_for_single_msm(const size_t msm_count) { - constexpr size_t num_precompute_rows_per_scalar = eccvm::NUM_WNAF_SLICES / eccvm::WNAF_SLICES_PER_ROW; + constexpr size_t num_precompute_rows_per_scalar = + eccvm::NUM_WNAF_DIGITS_PER_SCALAR / eccvm::WNAF_DIGITS_PER_ROW; const size_t num_rows_for_precompute_table = msm_count * num_precompute_rows_per_scalar; return static_cast(num_rows_for_precompute_table); }