Skip to content

Commit

Permalink
fix: Biggroup batch mul handles collisions (#6780)
Browse files Browse the repository at this point in the history
The following PR adds an edgecase handling mode to biggroup batch
multiplication. In this mode the points are randomised in such a way as
avoid weird interactions (doublings and point at infinity cases).
It enables using tables for RecurisveMergeVerifier and Recursive
Verifier for Protogalaxy on ultra (it also halves the gatecount from 4.5
mln to 2.2).
For a batch multiplication of 5 points it increases the gate count in
ultra from ~72k to ~78k

---------

Co-authored-by: Rumata888 <isennovskiy@gmail.com>
  • Loading branch information
codygunton and Rumata888 authored Jun 5, 2024
1 parent 05697f2 commit e61c40e
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ std::array<typename bn254<CircuitBuilder>::Element, 2> MergeRecursiveVerifier_<C
alpha_pow *= alpha;
}

auto batched_commitment = Commitment::batch_mul(commitments, scalars);
auto batched_commitment = Commitment::batch_mul(commitments, scalars, /*max_num_bits=*/0, /*with_edgecases=*/true);

OpeningClaim batched_claim = { { kappa, batched_eval }, batched_commitment };

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ template <class RecursiveBuilder> class RecursiveMergeVerifierTest : public test
// Run the recursive verifier tests with Ultra and Mega builders
// TODO(https://github.com/AztecProtocol/barretenberg/issues/1024): Ultra fails, possibly due to repeated points in
// batch mul?
// using Builders = testing::Types<MegaCircuitBuilder, UltraCircuitBuilder>;
using Builders = testing::Types<MegaCircuitBuilder>;
using Builders = testing::Types<MegaCircuitBuilder, UltraCircuitBuilder>;

TYPED_TEST_SUITE(RecursiveMergeVerifierTest, Builders);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,51 +125,6 @@ template <class VerifierInstances> class ProtoGalaxyRecursiveVerifier_ {
return result;
};

/**
* @brief Hack method to fold the witness commitments and verification key without the batch_mul in the case where
* the recursive folding verifier is instantiated as a vanilla ultra circuit.
*
* @details In the folding recursive verifier we might hit the scenerio where we do a batch_mul(commitments,
* lagranges) where the commitments are equal. That is because when we add gates to ensure no zero commitments,
* these will be the same for all circuits, hitting an edge case in batch_mul that creates a failing constraint.
* Specifically, at some point in the algorithm we compute the difference between the points which, if they are
* equal, would be zero, case that is not supported. See https://github.com/AztecProtocol/barretenberg/issues/971.
*/
void fold_commitments(std::vector<FF> lagranges,
VerifierInstances& instances,
std::shared_ptr<Instance>& accumulator)
requires IsUltraBuilder<Builder>
{
using ElementNative = typename Flavor::Curve::ElementNative;
using AffineElementNative = typename Flavor::Curve::AffineElementNative;

auto offset_generator = Commitment::from_witness(builder, AffineElementNative(ElementNative::random_element()));

size_t vk_idx = 0;
for (auto& expected_vk : accumulator->verification_key->get_all()) {
expected_vk = offset_generator;
size_t inst = 0;
for (auto& instance : instances) {
expected_vk += instance->verification_key->get_all()[vk_idx] * lagranges[inst];
inst++;
}
expected_vk -= offset_generator;
vk_idx++;
}

size_t comm_idx = 0;
for (auto& comm : accumulator->witness_commitments.get_all()) {
comm = offset_generator;
size_t inst = 0;
for (auto& instance : instances) {
comm += instance->witness_commitments.get_all()[comm_idx] * lagranges[inst];
inst++;
}
comm -= offset_generator;
comm_idx++;
}
}

/**
* @brief Folds the witness commitments and verification key (part of ϕ) and stores the values in the accumulator.
*
Expand All @@ -179,15 +134,16 @@ template <class VerifierInstances> class ProtoGalaxyRecursiveVerifier_ {
void fold_commitments(std::vector<FF> lagranges,
VerifierInstances& instances,
std::shared_ptr<Instance>& accumulator)
requires(!IsUltraBuilder<Builder>)
{
size_t vk_idx = 0;
for (auto& expected_vk : accumulator->verification_key->get_all()) {
std::vector<Commitment> commitments;
for (auto& instance : instances) {
commitments.emplace_back(instance->verification_key->get_all()[vk_idx]);
}
expected_vk = Commitment::batch_mul(commitments, lagranges);
// For ultra we need to enable edgecase prevention
expected_vk = Commitment::batch_mul(
commitments, lagranges, /*max_num_bits=*/0, /*with_edgecases=*/IsUltraBuilder<Builder>);
vk_idx++;
}

Expand All @@ -197,7 +153,9 @@ template <class VerifierInstances> class ProtoGalaxyRecursiveVerifier_ {
for (auto& instance : instances) {
commitments.emplace_back(instance->witness_commitments.get_all()[comm_idx]);
}
comm = Commitment::batch_mul(commitments, lagranges);
// For ultra we need to enable edgecase prevention
comm = Commitment::batch_mul(
commitments, lagranges, /*max_num_bits=*/0, /*with_edgecases=*/IsUltraBuilder<Builder>);
comm_idx++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
*this = *this - other;
return *this;
}
std::array<element, 2> checked_unconditional_add_sub(const element& other) const;
std::array<element, 2> checked_unconditional_add_sub(const element&) const;

element operator*(const Fr& other) const;

Expand Down Expand Up @@ -204,6 +204,9 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
return result;
}

static std::pair<std::vector<element>, std::vector<Fr>> mask_points(const std::vector<element>& _points,
const std::vector<Fr>& _scalars);

static std::pair<std::vector<element>, std::vector<Fr>> handle_points_at_infinity(
const std::vector<element>& _points, const std::vector<Fr>& _scalars);

Expand All @@ -215,7 +218,8 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
static element wnaf_batch_mul(const std::vector<element>& points, const std::vector<Fr>& scalars);
static element batch_mul(const std::vector<element>& points,
const std::vector<Fr>& scalars,
const size_t max_num_bits = 0);
const size_t max_num_bits = 0,
const bool with_edgecases = false);

// TODO(https://github.com/AztecProtocol/barretenberg/issues/707) max_num_bits is unused; could implement and use
// this to optimize other operations.
Expand Down Expand Up @@ -310,6 +314,7 @@ template <class Builder, class Fq, class Fr, class NativeGroup> class element {
const std::array<uint256_t, 8>& limb_max);

static std::pair<element, element> compute_offset_generators(const size_t num_rounds);
static typename NativeGroup::affine_element compute_table_offset_generator();

template <typename = typename std::enable_if<HasPlookup<Builder>>> struct four_bit_table_plookup {
four_bit_table_plookup(){};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,30 +449,72 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_batch_mul_edge_cases()
static void test_batch_mul_edgecase_equivalence()
{
{
// batch P + P = 2P
const size_t num_points = 5;
Builder builder;
std::vector<affine_element> points;
std::vector<fr> scalars;
for (size_t i = 0; i < num_points; ++i) {
points.push_back(affine_element(element::random_element()));
scalars.push_back(fr::random_element());
}

std::vector<element_ct> circuit_points;
std::vector<scalar_ct> circuit_scalars;
for (size_t i = 0; i < num_points; ++i) {
circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}

element_ct result_point2 =
element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true);

element expected_point = g1::one;
expected_point.self_set_infinity();
for (size_t i = 0; i < num_points; ++i) {
expected_point += (element(points[i]) * scalars[i]);
}

expected_point = expected_point.normalize();

fq result2_x(result_point2.x.get_value().lo);
fq result2_y(result_point2.y.get_value().lo);

EXPECT_EQ(result2_x, expected_point.x);
EXPECT_EQ(result2_y, expected_point.y);

EXPECT_CIRCUIT_CORRECTNESS(builder);
}

static void test_batch_mul_edge_case_set1()
{
const auto test_repeated_points = [](const uint32_t num_points) {
// batch P + ... + P = m*P
info("num points: ", num_points);
std::vector<affine_element> points;
points.push_back(affine_element::one());
points.push_back(affine_element::one());
std::vector<fr> scalars;
scalars.push_back(1);
scalars.push_back(1);
for (size_t idx = 0; idx < num_points; idx++) {
points.push_back(affine_element::one());
scalars.push_back(1);
}

Builder builder;
ASSERT(points.size() == scalars.size());
const size_t num_points = points.size();

std::vector<element_ct> circuit_points;
std::vector<scalar_ct> circuit_scalars;
for (size_t i = 0; i < num_points; ++i) {
circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}
element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars);
element_ct result_point =
element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true);

element expected_point = points[0] + points[1];
auto expected_point = element::infinity();
for (const auto& point : points) {
expected_point += point;
}
expected_point = expected_point.normalize();

fq result_x(result_point.x.get_value().lo);
Expand All @@ -482,7 +524,16 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
EXPECT_EQ(result_y, expected_point.y);

EXPECT_CIRCUIT_CORRECTNESS(builder);
}
};
test_repeated_points(2);
test_repeated_points(3);
test_repeated_points(4);
test_repeated_points(5);
test_repeated_points(6);
test_repeated_points(7);
}
static void test_batch_mul_edge_case_set2()
{
{
// batch oo + P = P
std::vector<affine_element> points;
Expand All @@ -502,7 +553,8 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}
element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars);
element_ct result_point =
element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true);

element expected_point = points[1];
expected_point = expected_point.normalize();
Expand Down Expand Up @@ -535,7 +587,8 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}

element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars);
element_ct result_point =
element_ct::batch_mul(circuit_points, circuit_scalars, /*max_num_bits=*/0, /*with_edgecases=*/true);

element expected_point = points[1];
expected_point = expected_point.normalize();
Expand Down Expand Up @@ -1177,10 +1230,24 @@ HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul)
{
TestFixture::test_batch_mul();
}
HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_cases)

HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edgecase_equivalence)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
GTEST_SKIP();
} else {
TestFixture::test_batch_mul_edgecase_equivalence();
}
}
HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_case_set1)
{
TestFixture::test_batch_mul_edge_case_set1();
}

HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_case_set2)
{
if constexpr (HasGoblinBuilder<TypeParam>) {
TestFixture::test_batch_mul_edge_cases();
TestFixture::test_batch_mul_edge_case_set2();
} else {
GTEST_SKIP() << "https://github.com/AztecProtocol/barretenberg/issues/1000";
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp"
#include "barretenberg/stdlib/primitives/biggroup/biggroup_edgecase_handling.hpp"
#include <cstddef>
namespace bb::stdlib {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#pragma once
#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp"

namespace bb::stdlib {

/**
* @brief Compute an offset generator for use in biggroup tables
*
*@details Sometimes the points from which we construct the tables are going to be dependent in such a way that
*combining them for constructing the table is not possible without handling the edgecases such as the point at infinity
*and doubling. To avoid handling those we add multiples of this offset generator to the points.
*
* @param num_rounds
*/
template <typename C, class Fq, class Fr, class G>
typename G::affine_element element<C, Fq, Fr, G>::compute_table_offset_generator()
{
constexpr typename G::affine_element offset_generator =
G::derive_generators("biggroup table offset generator", 1)[0];

return offset_generator;
}

/**
* @brief Given two lists of points that need to be multiplied by scalars, create a new list of length +1 with original
* points masked, but the same scalar product sum
* @details Add +1G, +2G, +4G etc to the original points and adds a new point 2ⁿ⋅G and scalar x to the lists. By
* doubling the point every time, we ensure that no +-1 combination of 6 sequential elements run into edgecases, unless
* the points are deliberately constructed to trigger it.
*/
template <typename C, class Fq, class Fr, class G>
std::pair<std::vector<element<C, Fq, Fr, G>>, std::vector<Fr>> element<C, Fq, Fr, G>::mask_points(
const std::vector<element>& _points, const std::vector<Fr>& _scalars)
{
std::vector<element> points;
std::vector<Fr> scalars;
ASSERT(_points.size() == _scalars.size());
using NativeFr = typename Fr::native;
auto running_scalar = NativeFr::one();
// Get the offset generator G_offset in native and in-circuit form
auto native_offset_generator = element::compute_table_offset_generator();
Fr last_scalar = Fr(0);
NativeFr generator_coefficient = NativeFr(2).pow(_points.size());
auto generator_coefficient_inverse = generator_coefficient.invert();
// For each point and scalar
for (size_t i = 0; i < _points.size(); i++) {
scalars.push_back(_scalars[i]);
// Convert point into point + 2ⁱ⋅G_offset
points.push_back(_points[i] + (native_offset_generator * running_scalar));
// Add \frac{2ⁱ⋅scalar}{2ⁿ} to the last scalar
last_scalar += _scalars[i] * (running_scalar * generator_coefficient_inverse);
// Double the running scalar
running_scalar += running_scalar;
}

// Add a scalar -(<(1,2,4,...,2ⁿ⁻¹ ),(scalar₀,...,scalarₙ₋₁)> / 2ⁿ)
scalars.push_back(-last_scalar);
// Add in-circuit G_offset to points
points.push_back(element(native_offset_generator * generator_coefficient));

return { points, scalars };
}

/**
* @brief Replace all pairs (∞, scalar) by the pair (one, 0) where one is a fixed generator of the curve
* @details This is a step in enabling our our multiscalar multiplication algorithms to hande points at infinity.
*/
template <typename C, class Fq, class Fr, class G>
std::pair<std::vector<element<C, Fq, Fr, G>>, std::vector<Fr>> element<C, Fq, Fr, G>::handle_points_at_infinity(
const std::vector<element>& _points, const std::vector<Fr>& _scalars)
{
auto builder = _points[0].get_context();
std::vector<element> points;
std::vector<Fr> scalars;
element one = element::one(builder);

for (auto [_point, _scalar] : zip_view(_points, _scalars)) {
bool_ct is_point_at_infinity = _point.is_point_at_infinity();
if (is_point_at_infinity.get_value() && static_cast<bool>(is_point_at_infinity.is_constant())) {
// if point is at infinity and a circuit constant we can just skip.
continue;
}
if (_scalar.get_value() == 0 && _scalar.is_constant()) {
// if scalar multiplier is 0 and also a constant, we can skip
continue;
}
Fq updated_x = Fq::conditional_assign(is_point_at_infinity, one.x, _point.x);
Fq updated_y = Fq::conditional_assign(is_point_at_infinity, one.y, _point.y);
element point(updated_x, updated_y);
Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalar);

points.push_back(point);
scalars.push_back(scalar);
// TODO(https://github.com/AztecProtocol/barretenberg/issues/1002): if both point and scalar are constant, don't
// bother adding constraints
}

return { points, scalars };
}
} // namespace bb::stdlib
Loading

0 comments on commit e61c40e

Please sign in to comment.