diff --git a/Earthfile b/Earthfile index de2ca028b..ba15020ce 100644 --- a/Earthfile +++ b/Earthfile @@ -78,8 +78,10 @@ barretenberg-acir-tests-bb.js: ENV VERBOSE=1 ENV TEST_SRC /usr/src/acir_artifacts - # Run double_verify_proof through bb.js on node to check 512k support. - RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify ./run_acir_tests.sh double_verify_proof + # TODO(https://github.com/noir-lang/noir/issues/5106) + # TODO(https://github.com/AztecProtocol/aztec-packages/issues/6672)c + # Run ecdsa_secp256r1_3x through bb.js on node to check 256k support. + RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify ./run_acir_tests.sh ecdsa_secp256r1_3x # Run a single arbitrary test not involving recursion through bb.js for UltraHonk RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_ultra_honk ./run_acir_tests.sh 6_array # Run a single arbitrary test not involving recursion through bb.js for MegaHonk @@ -88,11 +90,13 @@ barretenberg-acir-tests-bb.js: RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_goblin ./run_acir_tests.sh 6_array # Run 1_mul through bb.js build, all_cmds flow, to test all cli args. RUN BIN=../ts/dest/node/main.js FLOW=all_cmds ./run_acir_tests.sh 1_mul - # Run double_verify_proof through bb.js on chrome testing multi-threaded browser support. + # TODO(https://github.com/AztecProtocol/aztec-packages/issues/6672) + # Run 6_array through bb.js on chrome testing multi-threaded browser support. # TODO: Currently headless webkit doesn't seem to have shared memory so skipping multi-threaded test. - RUN BROWSER=chrome THREAD_MODEL=mt ./run_acir_tests_browser.sh double_verify_proof + RUN BROWSER=chrome THREAD_MODEL=mt ./run_acir_tests_browser.sh 6_array # Run 1_mul through bb.js on chrome/webkit testing single threaded browser support. RUN BROWSER=chrome THREAD_MODEL=st ./run_acir_tests_browser.sh 1_mul # Commenting for now as fails intermittently. Unreproducable on mainframe. # See https://github.com/AztecProtocol/aztec-packages/issues/2104 #RUN BROWSER=webkit THREAD_MODEL=st ./run_acir_tests_browser.sh 1_mul + \ No newline at end of file diff --git a/acir_tests/Dockerfile.bb.js b/acir_tests/Dockerfile.bb.js index 1de51a791..e485ba86b 100644 --- a/acir_tests/Dockerfile.bb.js +++ b/acir_tests/Dockerfile.bb.js @@ -13,8 +13,10 @@ RUN cd browser-test-app && yarn && yarn build RUN cd headless-test && yarn && npx playwright install && npx playwright install-deps COPY . . ENV VERBOSE=1 -# Run double_verify_proof through bb.js on node to check 512k support. -RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify ./run_acir_tests.sh double_verify_proof +# TODO(https://github.com/noir-lang/noir/issues/5106) +# TODO(https://github.com/AztecProtocol/aztec-packages/issues/6672) +# Run ecdsa_secp256r1_3x through bb.js on node to check 256k support. +RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify ./run_acir_tests.sh ecdsa_secp256r1_3x # Run a single arbitrary test not involving recursion through bb.js for UltraHonk RUN BIN=../ts/dest/node/main.js FLOW=prove_then_verify_ultra_honk ./run_acir_tests.sh nested_array_dynamic # Run a single arbitrary test not involving recursion through bb.js for Plonk @@ -27,9 +29,10 @@ RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_mega_honk ./run_acir_tests RUN BIN=../ts/dest/node/main.js FLOW=prove_and_verify_goblin ./run_acir_tests.sh 6_array # Run 1_mul through bb.js build, all_cmds flow, to test all cli args. RUN BIN=../ts/dest/node/main.js FLOW=all_cmds ./run_acir_tests.sh 1_mul -# Run double_verify_proof through bb.js on chrome testing multi-threaded browser support. +# TODO(https://github.com/AztecProtocol/aztec-packages/issues/6672) +# Run 6_array through bb.js on chrome testing multi-threaded browser support. # TODO: Currently headless webkit doesn't seem to have shared memory so skipping multi-threaded test. -RUN BROWSER=chrome THREAD_MODEL=mt ./run_acir_tests_browser.sh double_verify_proof +RUN BROWSER=chrome THREAD_MODEL=mt ./run_acir_tests_browser.sh 6_array # Run 1_mul through bb.js on chrome/webkit testing single threaded browser support. RUN BROWSER=chrome THREAD_MODEL=st ./run_acir_tests_browser.sh 1_mul # Commenting for now as fails intermittently. Unreproducable on mainframe. diff --git a/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp b/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp index e54a29eb2..1c2cb9286 100644 --- a/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp +++ b/cpp/src/barretenberg/dsl/acir_format/acir_integration.test.cpp @@ -6,6 +6,8 @@ #include #include +// #define LOG_SIZES + class AcirIntegrationTest : public ::testing::Test { public: static std::vector get_bytecode(const std::string& bytecodePath) @@ -53,17 +55,37 @@ class AcirIntegrationTest : public ::testing::Test { using VerificationKey = Flavor::VerificationKey; Prover prover{ builder }; - // builder.blocks.summarize(); - // info("num gates = ", builder.get_num_gates()); - // info("total circuit size = ", builder.get_total_circuit_size()); - // info("circuit size = ", prover.instance->proving_key.circuit_size); - // info("log circuit size = ", prover.instance->proving_key.log_circuit_size); +#ifdef LOG_SIZES + builder.blocks.summarize(); + info("num gates = ", builder.get_num_gates()); + info("total circuit size = ", builder.get_total_circuit_size()); + info("circuit size = ", prover.instance->proving_key.circuit_size); + info("log circuit size = ", prover.instance->proving_key.log_circuit_size); +#endif auto proof = prover.construct_proof(); - // Verify Honk proof auto verification_key = std::make_shared(prover.instance->proving_key); Verifier verifier{ verification_key }; + return verifier.verify_proof(proof); + } + template bool prove_and_verify_plonk(Flavor::CircuitBuilder& builder) + { + plonk::UltraComposer composer; + + auto prover = composer.create_prover(builder); +#ifdef LOG_SIZES + // builder.blocks.summarize(); + // info("num gates = ", builder.get_num_gates()); + // info("total circuit size = ", builder.get_total_circuit_size()); +#endif + auto proof = prover.construct_proof(); +#ifdef LOG_SIZES + // info("circuit size = ", prover.circuit_size); + // info("log circuit size = ", numeric::get_msb(prover.circuit_size)); +#endif + // Verify Plonk proof + auto verifier = composer.create_verifier(builder); return verifier.verify_proof(proof); } }; @@ -81,6 +103,7 @@ class AcirIntegrationFoldingTest : public AcirIntegrationTest, public testing::W TEST_P(AcirIntegrationSingleTest, ProveAndVerifyProgram) { using Flavor = MegaFlavor; + // using Flavor = bb::plonk::flavor::Ultra; using Builder = Flavor::CircuitBuilder; std::string test_name = GetParam(); @@ -91,7 +114,11 @@ TEST_P(AcirIntegrationSingleTest, ProveAndVerifyProgram) Builder builder = acir_format::create_circuit(acir_program.constraints, 0, acir_program.witness); // Construct and verify Honk proof - EXPECT_TRUE(prove_and_verify_honk(builder)); + if constexpr (IsPlonkFlavor) { + EXPECT_TRUE(prove_and_verify_plonk(builder)); + } else { + EXPECT_TRUE(prove_and_verify_honk(builder)); + } } // TODO(https://github.com/AztecProtocol/barretenberg/issues/994): Run all tests @@ -195,6 +222,7 @@ INSTANTIATE_TEST_SUITE_P(AcirTests, "double_verify_proof_recursive", "ecdsa_secp256k1", "ecdsa_secp256r1", + "ecdsa_secp256r1_3x", "eddsa", "embedded_curve_ops", "field_attribute", diff --git a/cpp/src/barretenberg/goblin/mock_circuits.hpp b/cpp/src/barretenberg/goblin/mock_circuits.hpp index 0c7da0149..0fbdef620 100644 --- a/cpp/src/barretenberg/goblin/mock_circuits.hpp +++ b/cpp/src/barretenberg/goblin/mock_circuits.hpp @@ -58,11 +58,11 @@ class GoblinMockCircuits { if (large) { stdlib::generate_sha256_test_circuit(builder, NUM_ITERATIONS_LARGE); - stdlib::generate_ecdsa_verification_test_circuit(builder, NUM_ITERATIONS_LARGE); + stdlib::generate_ecdsa_verification_test_circuit(builder, NUM_ITERATIONS_LARGE / 2); stdlib::generate_merkle_membership_test_circuit(builder, NUM_ITERATIONS_LARGE); } else { // Results in circuit size 2^17 when accumulated via ClientIvc stdlib::generate_sha256_test_circuit(builder, 5); - stdlib::generate_ecdsa_verification_test_circuit(builder, 2); + stdlib::generate_ecdsa_verification_test_circuit(builder, 1); stdlib::generate_merkle_membership_test_circuit(builder, 10); } @@ -153,7 +153,7 @@ class GoblinMockCircuits { { // Add operations representing general kernel logic e.g. state updates. Note: these are structured to make the // kernel "full" within the dyadic size 2^17 (130914 gates) - const size_t NUM_MERKLE_CHECKS = 45; + const size_t NUM_MERKLE_CHECKS = 40; const size_t NUM_ECDSA_VERIFICATIONS = 1; const size_t NUM_SHA_HASHES = 1; stdlib::generate_merkle_membership_test_circuit(builder, NUM_MERKLE_CHECKS); @@ -185,7 +185,7 @@ class GoblinMockCircuits { // Add operations representing general kernel logic e.g. state updates. Note: these are structured to make // the kernel "full" within the dyadic size 2^17 const size_t NUM_MERKLE_CHECKS = 20; - const size_t NUM_ECDSA_VERIFICATIONS = 2; + const size_t NUM_ECDSA_VERIFICATIONS = 1; const size_t NUM_SHA_HASHES = 1; stdlib::generate_merkle_membership_test_circuit(builder, NUM_MERKLE_CHECKS); stdlib::generate_ecdsa_verification_test_circuit(builder, NUM_ECDSA_VERIFICATIONS); diff --git a/cpp/src/barretenberg/goblin/mock_circuits_pinning.test.cpp b/cpp/src/barretenberg/goblin/mock_circuits_pinning.test.cpp index 3fe8894a7..40c42b10c 100644 --- a/cpp/src/barretenberg/goblin/mock_circuits_pinning.test.cpp +++ b/cpp/src/barretenberg/goblin/mock_circuits_pinning.test.cpp @@ -11,13 +11,13 @@ using namespace bb; * this, to the degree that matters for proof construction time, using these "pinning tests" that fix values. * */ -class MockCircuitsPinning : public ::testing::Test { +class MegaMockCircuitsPinning : public ::testing::Test { protected: using ProverInstance = ProverInstance_; static void SetUpTestSuite() { srs::init_crs_factory("../srs_db/ignition"); } }; -TEST_F(MockCircuitsPinning, FunctionSizes) +TEST_F(MegaMockCircuitsPinning, FunctionSizes) { const auto run_test = [](bool large) { Goblin goblin; @@ -34,7 +34,7 @@ TEST_F(MockCircuitsPinning, FunctionSizes) run_test(false); } -TEST_F(MockCircuitsPinning, RecursionKernelSizes) +TEST_F(MegaMockCircuitsPinning, RecursionKernelSizes) { const auto run_test = [](bool large) { { diff --git a/cpp/src/barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp b/cpp/src/barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp index 51c4a5584..b0c981ee7 100644 --- a/cpp/src/barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp +++ b/cpp/src/barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp @@ -69,6 +69,7 @@ struct ecc_op_tuple { uint32_t y_hi; uint32_t z_1; uint32_t z_2; + bool return_is_infinity; }; template inline void read(B& buf, poly_triple_& constraint) diff --git a/cpp/src/barretenberg/stdlib/plonk_recursion/verification_key/verification_key.test.cpp b/cpp/src/barretenberg/stdlib/plonk_recursion/verification_key/verification_key.test.cpp index 2f6969c62..0abc008cb 100644 --- a/cpp/src/barretenberg/stdlib/plonk_recursion/verification_key/verification_key.test.cpp +++ b/cpp/src/barretenberg/stdlib/plonk_recursion/verification_key/verification_key.test.cpp @@ -6,12 +6,13 @@ #include "barretenberg/stdlib_circuit_builders/standard_circuit_builder.hpp" #include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp" +using namespace bb; +using namespace bb::plonk; + namespace { auto& engine = numeric::get_debug_randomness(); } // namespace -using namespace bb::plonk; - /** * @brief A test fixture that will let us generate VK data and run tests * for all builder types diff --git a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp index e015988c5..3c49fb1c3 100644 --- a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.hpp @@ -246,6 +246,12 @@ template class bigfield { bigfield conditional_negate(const bool_t& predicate) const; bigfield conditional_select(const bigfield& other, const bool_t& predicate) const; + static bigfield conditional_assign(const bool_t& predicate, const bigfield& lhs, const bigfield& rhs) + { + return rhs.conditional_select(lhs, predicate); + } + + bool_t operator==(const bigfield& other) const; void assert_is_in_field() const; void assert_less_than(const uint256_t upper_limit) const; diff --git a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp index 695fe17ce..c996d7f30 100644 --- a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp +++ b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield.test.cpp @@ -840,6 +840,45 @@ template class stdlib_bigfield : public testing::Test { fq_ct ret = fq_ct::div_check_denominator_nonzero({}, a_ct); EXPECT_NE(ret.get_context(), nullptr); } + + static void test_assert_equal_not_equal() + { + auto builder = Builder(); + size_t num_repetitions = 10; + for (size_t i = 0; i < num_repetitions; ++i) { + fq inputs[4]{ fq::random_element(), fq::random_element(), fq::random_element(), fq::random_element() }; + + fq_ct a(witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct b(witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct c(witness_ct(&builder, fr(uint256_t(inputs[2]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[2]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + fq_ct d(witness_ct(&builder, fr(uint256_t(inputs[3]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), + witness_ct(&builder, + fr(uint256_t(inputs[3]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); + + fq_ct two(witness_ct(&builder, fr(2)), + witness_ct(&builder, fr(0)), + witness_ct(&builder, fr(0)), + witness_ct(&builder, fr(0))); + fq_ct t0 = a + a; + fq_ct t1 = a * two; + + t0.assert_equal(t1); + t0.assert_is_not_equal(c); + t0.assert_is_not_equal(d); + stdlib::bool_t is_equal_a = t0 == t1; + stdlib::bool_t is_equal_b = t0 == c; + EXPECT_TRUE(is_equal_a.get_value()); + EXPECT_FALSE(is_equal_b.get_value()); + } + bool result = CircuitChecker::check(builder); + EXPECT_EQ(result, true); + } }; // Define types for which the above tests will be constructed. @@ -929,6 +968,11 @@ TYPED_TEST(stdlib_bigfield, division_context) TestFixture::test_division_context(); } +TYPED_TEST(stdlib_bigfield, assert_equal_not_equal) +{ + TestFixture::test_assert_equal_not_equal(); +} + // // This test was disabled before the refactor to use TYPED_TEST's/ // TEST(stdlib_bigfield, DISABLED_test_div_against_constants) // { diff --git a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp index e6358fd95..c507f4192 100644 --- a/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/bigfield/bigfield_impl.hpp @@ -1568,6 +1568,63 @@ bigfield bigfield::conditional_select(const bigfield& ot return result; } +/** + * @brief Validate whether two bigfield elements are equal to each other + * @details To evaluate whether `(a == b)`, we use result boolean `r` to evaluate the following logic: + * (n.b all algebra involving bigfield elements is done in the bigfield) + * 1. If `r == 1` , `a - b == 0` + * 2. If `r == 0`, `a - b` posesses an inverse `I` i.e. `(a - b) * I - 1 == 0` + * We efficiently evaluate this logic by evaluating a single expression `(a - b)*X = Y` + * We use conditional assignment logic to define `X, Y` to be the following: + * If `r == 1` then `X = 1, Y = 0` + * If `r == 0` then `X = I, Y = 1` + * This allows us to evaluate `operator==` using only 1 bigfield multiplication operation. + * We can check the product equals 0 or 1 by directly evaluating the binary basis/prime basis limbs of Y. + * i.e. if `r == 1` then `(a - b)*X` should have 0 for all limb values + * if `r == 0` then `(a - b)*X` should have 1 in the least significant binary basis limb and 0 elsewhere + * @tparam Builder + * @tparam T + * @param other + * @return bool_t + */ +template bool_t bigfield::operator==(const bigfield& other) const +{ + Builder* ctx = context ? context : other.get_context(); + auto lhs = get_value() % modulus_u512; + auto rhs = other.get_value() % modulus_u512; + bool is_equal_raw = (lhs == rhs); + if (!ctx) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/660): null context _should_ mean that both are + // constant, but we check with an assertion to be sure. + ASSERT(is_constant() == other.is_constant()); + return is_equal_raw; + } + bool_t is_equal = witness_t(ctx, is_equal_raw); + + bigfield diff = (*this) - other; + + // TODO(https://github.com/AztecProtocol/barretenberg/issues/999): get native values efficiently (i.e. if u512 value + // fits in a u256, subtract off modulus until u256 fits into finite field) + native diff_native = native((diff.get_value() % modulus_u512).lo); + native inverse_native = is_equal_raw ? 0 : diff_native.invert(); + + bigfield inverse = bigfield::from_witness(ctx, inverse_native); + + bigfield multiplicand = bigfield::conditional_assign(is_equal, one(), inverse); + + bigfield product = diff * multiplicand; + + field_t result = field_t::conditional_assign(is_equal, 0, 1); + + product.prime_basis_limb.assert_equal(result); + product.binary_basis_limbs[0].element.assert_equal(result); + product.binary_basis_limbs[1].element.assert_equal(0); + product.binary_basis_limbs[2].element.assert_equal(0); + product.binary_basis_limbs[3].element.assert_equal(0); + + return is_equal; +} + /** * REDUCTION CHECK * @@ -1767,6 +1824,7 @@ template void bigfield::assert_equal( << std::endl; return; } else if (other.is_constant()) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/998): Something is fishy here // evaluate a strict equality - make sure *this is reduced first, or an honest prover // might not be able to satisfy these constraints. field_t t0 = (binary_basis_limbs[0].element - other.binary_basis_limbs[0].element); @@ -1783,24 +1841,47 @@ template void bigfield::assert_equal( } else if (is_constant()) { other.assert_equal(*this); return; - } + } else { + if (is_constant() && other.is_constant()) { + std::cerr << "bigfield: calling assert equal on 2 CONSTANT bigfield elements...is this intended?" + << std::endl; + return; + } else if (other.is_constant()) { + // evaluate a strict equality - make sure *this is reduced first, or an honest prover + // might not be able to satisfy these constraints. + field_t t0 = (binary_basis_limbs[0].element - other.binary_basis_limbs[0].element); + field_t t1 = (binary_basis_limbs[1].element - other.binary_basis_limbs[1].element); + field_t t2 = (binary_basis_limbs[2].element - other.binary_basis_limbs[2].element); + field_t t3 = (binary_basis_limbs[3].element - other.binary_basis_limbs[3].element); + field_t t4 = (prime_basis_limb - other.prime_basis_limb); + t0.assert_is_zero(); + t1.assert_is_zero(); + t2.assert_is_zero(); + t3.assert_is_zero(); + t4.assert_is_zero(); + return; + } else if (is_constant()) { + other.assert_equal(*this); + return; + } - bigfield diff = *this - other; - const uint512_t diff_val = diff.get_value(); - const uint512_t modulus(target_basis.modulus); - - const auto [quotient_512, remainder_512] = (diff_val).divmod(modulus); - if (remainder_512 != 0) - std::cerr << "bigfield: remainder not zero!" << std::endl; - ASSERT(remainder_512 == 0); - bigfield quotient; - - const size_t num_quotient_bits = get_quotient_max_bits({ 0 }); - quotient = bigfield(witness_t(ctx, fr(quotient_512.slice(0, NUM_LIMB_BITS * 2).lo)), - witness_t(ctx, fr(quotient_512.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 4).lo)), - false, - num_quotient_bits); - unsafe_evaluate_multiply_add(diff, { one() }, {}, quotient, { zero() }); + bigfield diff = *this - other; + const uint512_t diff_val = diff.get_value(); + const uint512_t modulus(target_basis.modulus); + + const auto [quotient_512, remainder_512] = (diff_val).divmod(modulus); + if (remainder_512 != 0) + std::cerr << "bigfield: remainder not zero!" << std::endl; + ASSERT(remainder_512 == 0); + bigfield quotient; + + const size_t num_quotient_bits = get_quotient_max_bits({ 0 }); + quotient = bigfield(witness_t(ctx, fr(quotient_512.slice(0, NUM_LIMB_BITS * 2).lo)), + witness_t(ctx, fr(quotient_512.slice(NUM_LIMB_BITS * 2, NUM_LIMB_BITS * 4).lo)), + false, + num_quotient_bits); + unsafe_evaluate_multiply_add(diff, { one() }, {}, quotient, { zero() }); + } } } diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp index f069379cf..0e0ab416c 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp @@ -21,6 +21,8 @@ namespace bb::stdlib { // ( ͡° ͜ʖ ͡°) template class element { public: + using bool_ct = stdlib::bool_t; + struct secp256k1_wnaf { std::vector> wnaf; field_t positive_skew; @@ -38,13 +40,23 @@ template class element { element(const Fq& x, const Fq& y); element(const element& other); - element(element&& other); + element(element&& other) noexcept; static element from_witness(Builder* ctx, const typename NativeGroup::affine_element& input) { - Fq x = Fq::from_witness(ctx, input.x); - Fq y = Fq::from_witness(ctx, input.y); - element out(x, y); + element out; + if (input.is_point_at_infinity()) { + Fq x = Fq::from_witness(ctx, NativeGroup::affine_one.x); + Fq y = Fq::from_witness(ctx, NativeGroup::affine_one.y); + out.x = x; + out.y = y; + } else { + Fq x = Fq::from_witness(ctx, input.x); + Fq y = Fq::from_witness(ctx, input.y); + out.x = x; + out.y = y; + } + out.set_point_at_infinity(witness_t(ctx, input.is_point_at_infinity())); out.validate_on_curve(); return out; } @@ -52,13 +64,17 @@ template class element { void validate_on_curve() const { Fq b(get_context(), uint256_t(NativeGroup::curve_b)); + Fq _b = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), b); + Fq _x = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), x); + Fq _y = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), y); if constexpr (!NativeGroup::has_a) { // we validate y^2 = x^3 + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ x.sqr(), y }, { x, -y }, { b }, true); + Fq::mult_madd({ _x.sqr(), _y }, { _x, -_y }, { _b }, true); } else { Fq a(get_context(), uint256_t(NativeGroup::curve_a)); + Fq _a = Fq::conditional_assign(is_point_at_infinity(), Fq::zero(), a); // we validate y^2 = x^3 + ax + b by setting "fix_remainder_zero = true" when calling mult_madd - Fq::mult_madd({ x.sqr(), x, y }, { x, a, -y }, { b }, true); + Fq::mult_madd({ _x.sqr(), _x, _y }, { _x, _a, -_y }, { _b }, true); } } @@ -72,7 +88,7 @@ template class element { } element& operator=(const element& other); - element& operator=(element&& other); + element& operator=(element&& other) noexcept; byte_array to_byte_array() const { @@ -82,6 +98,9 @@ template class element { return result; } + element checked_unconditional_add(const element& other) const; + element checked_unconditional_subtract(const element& other) const; + element operator+(const element& other) const; element operator-(const element& other) const; element operator-() const @@ -100,11 +119,11 @@ template class element { *this = *this - other; return *this; } - std::array add_sub(const element& other) const; + std::array checked_unconditional_add_sub(const element& other) const; element operator*(const Fr& other) const; - element conditional_negate(const bool_t& predicate) const + element conditional_negate(const bool_ct& predicate) const { element result(*this); result.y = result.y.conditional_negate(predicate); @@ -176,11 +195,18 @@ template class element { typename NativeGroup::affine_element get_value() const { - uint512_t x_val = x.get_value(); - uint512_t y_val = y.get_value(); - return typename NativeGroup::affine_element(x_val.lo, y_val.lo); + uint512_t x_val = x.get_value() % Fq::modulus_u512; + uint512_t y_val = y.get_value() % Fq::modulus_u512; + auto result = typename NativeGroup::affine_element(x_val.lo, y_val.lo); + if (is_point_at_infinity().get_value()) { + result.self_set_infinity(); + } + return result; } + static std::pair, std::vector> handle_points_at_infinity( + const std::vector& _points, const std::vector& _scalars); + // compute a multi-scalar-multiplication by creating a precomputed lookup table for each point, // splitting each scalar multiplier up into a 4-bit sliding window wNAF. // more efficient than batch_mul if num_points < 4 @@ -229,7 +255,7 @@ template class element { template ::value>> static element secp256k1_ecdsa_mul(const element& pubkey, const Fr& u1, const Fr& u2); - static std::vector> compute_naf(const Fr& scalar, const size_t max_num_bits = 0); + static std::vector compute_naf(const Fr& scalar, const size_t max_num_bits = 0); template static std::vector> compute_wnaf(const Fr& scalar); @@ -265,10 +291,15 @@ template class element { return nullptr; } + bool_ct is_point_at_infinity() const { return _is_infinity; } + void set_point_at_infinity(const bool_ct& is_infinity) { _is_infinity = is_infinity; } + Fq x; Fq y; private: + bool_ct _is_infinity; + template >> static std::array, 5> create_group_element_rom_tables( const std::array& elements, std::array& limb_max); @@ -367,7 +398,7 @@ template class element { lookup_table_base(const lookup_table_base& other) = default; lookup_table_base& operator=(const lookup_table_base& other) = default; - element get(const std::array, length>& bits) const; + element get(const std::array& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -397,7 +428,7 @@ template class element { lookup_table_plookup(const lookup_table_plookup& other) = default; lookup_table_plookup& operator=(const lookup_table_plookup& other) = default; - element get(const std::array, length>& bits) const; + element get(const std::array& bits) const; element operator[](const size_t idx) const { return element_table[idx]; } @@ -608,7 +639,7 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -660,7 +691,7 @@ template class element { return (accumulator); } - element get(std::vector>& naf_entries) const + element get(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_sixes; ++j) { @@ -812,21 +843,21 @@ template class element { return chain_add_accumulator(add_accumulator[0]); } - element::chain_add_accumulator get_chain_add_accumulator(std::vector>& naf_entries) const + element::chain_add_accumulator get_chain_add_accumulator(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { - round_accumulator.push_back(quad_tables[j].get(std::array, 4>{ + round_accumulator.push_back(quad_tables[j].get(std::array{ naf_entries[4 * j], naf_entries[4 * j + 1], naf_entries[4 * j + 2], naf_entries[4 * j + 3] })); } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ + round_accumulator.push_back(triple_tables[0].get(std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { round_accumulator.push_back(twin_tables[0].get( - std::array, 2>{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); + std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1] })); } if (has_singleton) { round_accumulator.push_back(singletons[0].conditional_negate(naf_entries[num_points - 1])); @@ -849,7 +880,7 @@ template class element { return (accumulator); } - element get(std::vector>& naf_entries) const + element get(std::vector& naf_entries) const { std::vector round_accumulator; for (size_t j = 0; j < num_quads; ++j) { @@ -858,7 +889,7 @@ template class element { } if (has_triple) { - round_accumulator.push_back(triple_tables[0].get(std::array, 3>{ + round_accumulator.push_back(triple_tables[0].get(std::array{ naf_entries[num_quads * 4], naf_entries[num_quads * 4 + 1], naf_entries[num_quads * 4 + 2] })); } if (has_twin) { diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp index 7071583ff..1069e722b 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/stdlib/primitives/curves/secp256k1.hpp" #include "barretenberg/stdlib/primitives/curves/secp256r1.hpp" +using namespace bb; + namespace { auto& engine = numeric::get_debug_randomness(); } -using namespace bb; - // One can only define a TYPED_TEST with a single template paramter. // Our workaround is to pass parameters of the following type. template struct TestType { @@ -41,6 +41,8 @@ template class stdlib_biggroup : public testing::Test { using element = typename g1::element; using Builder = typename Curve::Builder; + using witness_ct = stdlib::witness_t; + using bool_ct = stdlib::bool_t; static constexpr auto EXPECT_CIRCUIT_CORRECTNESS = [](Builder& builder, bool expected_result = true) { info("num gates = ", builder.get_num_gates()); @@ -82,6 +84,45 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_add_points_at_infinity() + { + Builder builder; + size_t num_repetitions = 1; + for (size_t i = 0; i < num_repetitions; ++i) { + affine_element input_a(element::random_element()); + affine_element input_b(element::random_element()); + input_b.self_set_infinity(); + element_ct a = element_ct::from_witness(&builder, input_a); + // create copy of a with different witness + element_ct a_alternate = element_ct::from_witness(&builder, input_a); + element_ct a_negated = element_ct::from_witness(&builder, -input_a); + element_ct b = element_ct::from_witness(&builder, input_b); + + element_ct c = a + b; + element_ct d = b + a; + element_ct e = b + b; + element_ct f = a + a; + element_ct g = a + a_alternate; + element_ct h = a + a_negated; + + affine_element c_expected = affine_element(element(input_a) + element(input_b)); + affine_element d_expected = affine_element(element(input_b) + element(input_a)); + affine_element e_expected = affine_element(element(input_b) + element(input_b)); + affine_element f_expected = affine_element(element(input_a) + element(input_a)); + affine_element g_expected = affine_element(element(input_a) + element(input_a)); + affine_element h_expected = affine_element(element(input_a) + element(-input_a)); + + EXPECT_EQ(c.get_value(), c_expected); + EXPECT_EQ(d.get_value(), d_expected); + EXPECT_EQ(e.get_value(), e_expected); + EXPECT_EQ(f.get_value(), f_expected); + EXPECT_EQ(g.get_value(), g_expected); + EXPECT_EQ(h.get_value(), h_expected); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + static void test_sub() { Builder builder; @@ -110,6 +151,45 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_sub_points_at_infinity() + { + Builder builder; + size_t num_repetitions = 1; + for (size_t i = 0; i < num_repetitions; ++i) { + affine_element input_a(element::random_element()); + affine_element input_b(element::random_element()); + input_b.self_set_infinity(); + element_ct a = element_ct::from_witness(&builder, input_a); + // create copy of a with different witness + element_ct a_alternate = element_ct::from_witness(&builder, input_a); + element_ct a_negated = element_ct::from_witness(&builder, -input_a); + element_ct b = element_ct::from_witness(&builder, input_b); + + element_ct c = a - b; + element_ct d = b - a; + element_ct e = b - b; + element_ct f = a - a; + element_ct g = a - a_alternate; + element_ct h = a - a_negated; + + affine_element c_expected = affine_element(element(input_a) - element(input_b)); + affine_element d_expected = affine_element(element(input_b) - element(input_a)); + affine_element e_expected = affine_element(element(input_b) - element(input_b)); + affine_element f_expected = affine_element(element(input_a) - element(input_a)); + affine_element g_expected = affine_element(element(input_a) - element(input_a)); + affine_element h_expected = affine_element(element(input_a) - element(-input_a)); + + EXPECT_EQ(c.get_value(), c_expected); + EXPECT_EQ(d.get_value(), d_expected); + EXPECT_EQ(e.get_value(), e_expected); + EXPECT_EQ(f.get_value(), f_expected); + EXPECT_EQ(g.get_value(), g_expected); + EXPECT_EQ(h.get_value(), h_expected); + } + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + static void test_dbl() { Builder builder; @@ -369,6 +449,107 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_batch_mul_edge_cases() + { + { + // batch P + P = 2P + std::vector points; + points.push_back(affine_element::one()); + points.push_back(affine_element::one()); + std::vector scalars; + scalars.push_back(1); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[0] + points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + { + // batch oo + P = P + std::vector points; + points.push_back(affine_element::infinity()); + points.push_back(affine_element(element::random_element())); + std::vector scalars; + scalars.push_back(1); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + { + // batch 0 * P1 + P2 = P2 + std::vector points; + points.push_back(affine_element(element::random_element())); + points.push_back(affine_element(element::random_element())); + std::vector scalars; + scalars.push_back(0); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + + element_ct result_point = element_ct::batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + } + static void test_chain_add() { Builder builder = Builder(); @@ -486,6 +667,107 @@ template class stdlib_biggroup : public testing::Test { EXPECT_CIRCUIT_CORRECTNESS(builder); } + static void test_wnaf_batch_mul_edge_cases() + { + { + // batch P + P = 2P + std::vector points; + points.push_back(affine_element::one()); + points.push_back(affine_element::one()); + std::vector scalars; + scalars.push_back(1); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + element_ct result_point = element_ct::wnaf_batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[0] + points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + { + // batch oo + P = P + std::vector points; + points.push_back(affine_element::infinity()); + points.push_back(affine_element(element::random_element())); + std::vector scalars; + scalars.push_back(1); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + element_ct result_point = element_ct::wnaf_batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + { + // batch 0 * P1 + P2 = P2 + std::vector points; + points.push_back(affine_element(element::random_element())); + points.push_back(affine_element(element::random_element())); + std::vector scalars; + scalars.push_back(0); + scalars.push_back(1); + + Builder builder; + ASSERT(points.size() == scalars.size()); + const size_t num_points = points.size(); + + std::vector circuit_points; + std::vector circuit_scalars; + for (size_t i = 0; i < num_points; ++i) { + circuit_points.push_back(element_ct::from_witness(&builder, points[i])); + circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i])); + } + + element_ct result_point = element_ct::wnaf_batch_mul(circuit_points, circuit_scalars); + + element expected_point = points[1]; + expected_point = expected_point.normalize(); + + fq result_x(result_point.x.get_value().lo); + fq result_y(result_point.y.get_value().lo); + + EXPECT_EQ(result_x, expected_point.x); + EXPECT_EQ(result_y, expected_point.y); + + EXPECT_CIRCUIT_CORRECTNESS(builder); + } + } + static void test_batch_mul_short_scalars() { const size_t num_points = 11; @@ -834,10 +1116,19 @@ TYPED_TEST(stdlib_biggroup, add) TestFixture::test_add(); } +TYPED_TEST(stdlib_biggroup, add_points_at_infinity) +{ + TestFixture::test_add_points_at_infinity(); +} TYPED_TEST(stdlib_biggroup, sub) { TestFixture::test_sub(); } +TYPED_TEST(stdlib_biggroup, sub_points_at_infinity) +{ + + TestFixture::test_sub_points_at_infinity(); +} TYPED_TEST(stdlib_biggroup, dbl) { TestFixture::test_dbl(); @@ -886,6 +1177,14 @@ HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul) { TestFixture::test_batch_mul(); } +HEAVY_TYPED_TEST(stdlib_biggroup, batch_mul_edge_cases) +{ + if constexpr (HasGoblinBuilder) { + TestFixture::test_batch_mul_edge_cases(); + } else { + GTEST_SKIP() << "https://github.com/AztecProtocol/barretenberg/issues/1000"; + }; +} HEAVY_TYPED_TEST(stdlib_biggroup, chain_add) { @@ -932,6 +1231,20 @@ HEAVY_TYPED_TEST(stdlib_biggroup, wnaf_batch_mul) } } +/* These tests only work for Ultra Circuit Constructor */ +HEAVY_TYPED_TEST(stdlib_biggroup, wnaf_batch_mul_edge_cases) +{ + if constexpr (HasPlookup) { + if constexpr (HasGoblinBuilder) { + GTEST_SKIP() << "https://github.com/AztecProtocol/barretenberg/issues/707"; + } else { + TestFixture::test_compute_wnaf(); + }; + } else { + GTEST_SKIP(); + } +} + /* the following test was only developed as a test of Ultra Circuit Constructor. It fails for Standard in the case where Fr is a bigfield. */ HEAVY_TYPED_TEST(stdlib_biggroup, compute_wnaf) diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp index a10198286..e931e1c63 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_batch_mul.hpp @@ -1,23 +1,29 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp" +#include namespace bb::stdlib { /** - * only works for Plookup (otherwise falls back on batch_mul)! Multiscalar multiplication that utilizes 4-bit wNAF - * lookup tables is more efficient than points-as-linear-combinations lookup tables, if the number of points is 3 or - * fewer + * @brief Multiscalar multiplication that utilizes 4-bit wNAF lookup tables. + * @details This is more efficient than points-as-linear-combinations lookup tables, if the number of points is 3 or + * fewer. Only works for Plookup (otherwise falls back on batch_mul)! + * @todo : TODO(https://github.com/AztecProtocol/barretenberg/issues/1001) when we nuke standard and turbo plonk we + * should remove the fallback batch mul method! */ template template -element element::wnaf_batch_mul(const std::vector& points, - const std::vector& scalars) +element element::wnaf_batch_mul(const std::vector& _points, + const std::vector& _scalars) { constexpr size_t WNAF_SIZE = 4; - ASSERT(points.size() == scalars.size()); + ASSERT(_points.size() == _scalars.size()); if constexpr (!HasPlookup) { - return batch_mul(points, scalars, max_num_bits); + return batch_mul(_points, _scalars, max_num_bits); } + const auto [points, scalars] = handle_points_at_infinity(_points, _scalars); + std::vector> point_tables; for (const auto& point : points) { point_tables.emplace_back(four_bit_table_plookup<>(point)); @@ -49,8 +55,8 @@ element element::wnaf_batch_mul(const std::vector(wnaf_entries[i][num_rounds])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(wnaf_entries[i][num_rounds])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_ct(wnaf_entries[i][num_rounds])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_ct(wnaf_entries[i][num_rounds])); accumulator = element(out_x, out_y); } accumulator -= offset_generators.second; diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp index d82230eed..b6e7f887a 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_bn254.hpp @@ -7,6 +7,8 @@ * We use a special case algorithm to split bn254 scalar multipliers into endomorphism scalars * **/ +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" +#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp" namespace bb::stdlib { /** @@ -54,9 +56,9 @@ element element::bn254_endo_batch_mul_with_generator auto& big_table = big_table_pair.first; auto& endo_table = big_table_pair.second; batch_lookup_table small_table(small_points); - std::vector>> big_naf_entries; - std::vector>> endo_naf_entries; - std::vector>> small_naf_entries; + std::vector> big_naf_entries; + std::vector> endo_naf_entries; + std::vector> small_naf_entries; const auto split_into_endomorphism_scalars = [ctx](const Fr& scalar) { bb::fr k = scalar.get_value(); @@ -99,9 +101,9 @@ element element::bn254_endo_batch_mul_with_generator element accumulator = element::chain_add_end(init_point); const auto get_point_to_add = [&](size_t naf_index) { - std::vector> small_nafs; - std::vector> big_nafs; - std::vector> endo_nafs; + std::vector small_nafs; + std::vector big_nafs; + std::vector endo_nafs; for (size_t i = 0; i < small_points.size(); ++i) { small_nafs.emplace_back(small_naf_entries[i][naf_index]); } @@ -178,16 +180,16 @@ element element::bn254_endo_batch_mul_with_generator } { element skew = accumulator - generator_table[128]; - Fq out_x = accumulator.x.conditional_select(skew.x, bool_t(generator_wnaf[generator_wnaf.size() - 1])); - Fq out_y = accumulator.y.conditional_select(skew.y, bool_t(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_x = accumulator.x.conditional_select(skew.x, bool_ct(generator_wnaf[generator_wnaf.size() - 1])); + Fq out_y = accumulator.y.conditional_select(skew.y, bool_ct(generator_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } { element skew = accumulator - generator_endo_table[128]; Fq out_x = - accumulator.x.conditional_select(skew.x, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + accumulator.x.conditional_select(skew.x, bool_ct(generator_endo_wnaf[generator_wnaf.size() - 1])); Fq out_y = - accumulator.y.conditional_select(skew.y, bool_t(generator_endo_wnaf[generator_wnaf.size() - 1])); + accumulator.y.conditional_select(skew.y, bool_ct(generator_endo_wnaf[generator_wnaf.size() - 1])); accumulator = element(out_x, out_y); } @@ -320,7 +322,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ const size_t num_rounds = max_num_small_bits; const size_t num_points = points.size(); - std::vector>> naf_entries; + std::vector> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_small_bits)); } @@ -354,7 +356,7 @@ element element::bn254_endo_batch_mul(const std::vec **/ for (size_t i = 1; i < num_rounds / 2; ++i) { // `nafs` tracks the naf value for each point for the current round - std::vector> nafs; + std::vector nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][i * 2 - 1]); } @@ -383,7 +385,7 @@ element element::bn254_endo_batch_mul(const std::vec // we need to iterate 1 more time if the number of rounds is even if ((num_rounds & 0x01ULL) == 0x00ULL) { - std::vector> nafs; + std::vector nafs; for (size_t j = 0; j < points.size(); ++j) { nafs.emplace_back(naf_entries[j][num_rounds - 1]); } diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp index 62404fc05..ef0e0fcb4 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp @@ -1,5 +1,6 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { /** @@ -87,8 +88,12 @@ element element::goblin_batch_mul(const std::vector< auto y_hi = Fr::from_witness_index(builder, op_tuple.y_hi); Fq point_x(x_lo, x_hi); Fq point_y(y_lo, y_hi); + element result = element(point_x, point_y); + if (op_tuple.return_is_infinity) { + result.set_point_at_infinity(bool_ct(builder, true)); + }; - return element(point_x, point_y); + return result; } } // namespace bb::stdlib diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp index 9f9772043..bca043bb0 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.test.cpp @@ -10,12 +10,12 @@ #include "barretenberg/numeric/random/engine.hpp" #include +using namespace bb; + namespace { auto& engine = numeric::get_debug_randomness(); } -using namespace bb; - template class stdlib_biggroup_goblin : public testing::Test { using element_ct = typename Curve::Element; using scalar_ct = typename Curve::ScalarField; diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp index 9d25f5874..3d2752aa2 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp @@ -2,8 +2,7 @@ #include "../bit_array/bit_array.hpp" #include "../circuit_builders/circuit_builders.hpp" - -using namespace bb; +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { @@ -11,50 +10,184 @@ template element::element() : x() , y() + , _is_infinity() {} template element::element(const typename G::affine_element& input) : x(nullptr, input.x) , y(nullptr, input.y) + , _is_infinity(nullptr, input.is_point_at_infinity()) {} template element::element(const Fq& x_in, const Fq& y_in) : x(x_in) , y(y_in) + , _is_infinity(x.get_context() ? x.get_context() : y.get_context(), false) {} template element::element(const element& other) : x(other.x) , y(other.y) + , _is_infinity(other.is_point_at_infinity()) {} template -element::element(element&& other) +element::element(element&& other) noexcept : x(other.x) , y(other.y) + , _is_infinity(other.is_point_at_infinity()) {} template element& element::operator=(const element& other) { + if (&other == this) { + return *this; + } x = other.x; y = other.y; + _is_infinity = other.is_point_at_infinity(); return *this; } template -element& element::operator=(element&& other) +element& element::operator=(element&& other) noexcept { + if (&other == this) { + return *this; + } x = other.x; y = other.y; + _is_infinity = other.is_point_at_infinity(); return *this; } template element element::operator+(const element& other) const +{ + // return checked_unconditional_add(other); + if constexpr (IsMegaBuilder && std::same_as) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize + // Current gate count: 6398 + std::vector points{ *this, other }; + std::vector scalars{ 1, 1 }; + return batch_mul(points, scalars); + } + + // Adding in `x_coordinates_match` ensures that lambda will always be well-formed + // Our curve has the form y^2 = x^3 + b. + // If (x_1, y_1), (x_2, y_2) have x_1 == x_2, and the generic formula for lambda has a division by 0. + // Then y_1 == y_2 (i.e. we are doubling) or y_2 == y_1 (the sum is infinity). + // The cases have a special addition formula. The following booleans allow us to handle these cases uniformly. + const bool_ct x_coordinates_match = other.x == x; + const bool_ct y_coordinates_match = (y == other.y); + const bool_ct infinity_predicate = (x_coordinates_match && !y_coordinates_match); + const bool_ct double_predicate = (x_coordinates_match && y_coordinates_match); + const bool_ct lhs_infinity = is_point_at_infinity(); + const bool_ct rhs_infinity = other.is_point_at_infinity(); + + // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 + const Fq add_lambda_numerator = other.y - y; + const Fq xx = x * x; + const Fq dbl_lambda_numerator = xx + xx + xx; + const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); + + const Fq add_lambda_denominator = other.x - x; + const Fq dbl_lambda_denominator = y + y; + Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); + // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a + // divide by zero error. + // Note: if either inputs are points at infinity we will not use the result of this computation. + Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); + lambda_denominator = Fq::conditional_assign( + lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); + const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); + + const Fq x3 = lambda.sqradd({ -other.x, -x }); + const Fq y3 = lambda.madd(x - x3, { -y }); + + element result(x3, y3); + // if lhs infinity, return rhs + result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); + result.y = Fq::conditional_assign(lhs_infinity, other.y, result.y); + // if rhs infinity, return lhs + result.x = Fq::conditional_assign(rhs_infinity, x, result.x); + result.y = Fq::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. can likely optimize this + bool_ct result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + return result; +} + +template +element element::operator-(const element& other) const +{ + // return checked_unconditional_add(other); + if constexpr (IsMegaBuilder && std::same_as) { + // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize + // Current gate count: 6398 + std::vector points{ *this, other }; + std::vector scalars{ 1, -Fr(1) }; + return batch_mul(points, scalars); + } + + // if x_coordinates match, lambda triggers a divide by zero error. + // Adding in `x_coordinates_match` ensures that lambda will always be well-formed + const bool_ct x_coordinates_match = other.x == x; + const bool_ct y_coordinates_match = (y == other.y); + const bool_ct infinity_predicate = (x_coordinates_match && y_coordinates_match); + const bool_ct double_predicate = (x_coordinates_match && !y_coordinates_match); + const bool_ct lhs_infinity = is_point_at_infinity(); + const bool_ct rhs_infinity = other.is_point_at_infinity(); + + // Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1 + const Fq add_lambda_numerator = -other.y - y; + const Fq xx = x * x; + const Fq dbl_lambda_numerator = xx + xx + xx; + const Fq lambda_numerator = Fq::conditional_assign(double_predicate, dbl_lambda_numerator, add_lambda_numerator); + + const Fq add_lambda_denominator = other.x - x; + const Fq dbl_lambda_denominator = y + y; + Fq lambda_denominator = Fq::conditional_assign(double_predicate, dbl_lambda_denominator, add_lambda_denominator); + // If either inputs are points at infinity, we set lambda_denominator to be 1. This ensures we never trigger a + // divide by zero error. + // (if either inputs are points at infinity we will not use the result of this computation) + Fq safe_edgecase_denominator = Fq(field_t(1), field_t(0), field_t(0), field_t(0)); + lambda_denominator = Fq::conditional_assign( + lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator); + const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator); + + const Fq x3 = lambda.sqradd({ -other.x, -x }); + const Fq y3 = lambda.madd(x - x3, { -y }); + + element result(x3, y3); + // if lhs infinity, return rhs + result.x = Fq::conditional_assign(lhs_infinity, other.x, result.x); + result.y = Fq::conditional_assign(lhs_infinity, -other.y, result.y); + // if rhs infinity, return lhs + result.x = Fq::conditional_assign(rhs_infinity, x, result.x); + result.y = Fq::conditional_assign(rhs_infinity, y, result.y); + + // is result point at infinity? + // yes = infinity_predicate && !lhs_infinity && !rhs_infinity + // yes = lhs_infinity && rhs_infinity + // n.b. can likely optimize this + bool_ct result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity); + result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity); + result.set_point_at_infinity(result_is_infinity); + return result; +} + +template +element element::checked_unconditional_add(const element& other) const { if constexpr (IsMegaBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -72,7 +205,7 @@ element element::operator+(const element& other) con } template -element element::operator-(const element& other) const +element element::checked_unconditional_subtract(const element& other) const { if constexpr (IsMegaBuilder && std::same_as) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/707) Optimize @@ -105,7 +238,7 @@ element element::operator-(const element& other) con */ // TODO(https://github.com/AztecProtocol/barretenberg/issues/657): This function is untested template -std::array, 2> element::add_sub(const element& other) const +std::array, 2> element::checked_unconditional_add_sub(const element& other) const { if constexpr (IsMegaBuilder && std::same_as) { return { *this + other, *this - other }; @@ -142,7 +275,9 @@ template element element Fq neg_lambda = Fq::msub_div({ x }, { (two_x + x) }, (y + y), {}); Fq x_3 = neg_lambda.sqradd({ -(two_x) }); Fq y_3 = neg_lambda.madd(x_3 - x, { -y }); - return element(x_3, y_3); + element result = element(x_3, y_3); + result.set_point_at_infinity(is_point_at_infinity()); + return result; } /** @@ -619,10 +754,12 @@ std::pair, element> element::c * scalars See `bn254_endo_batch_mul` for description of algorithm **/ template -element element::batch_mul(const std::vector& points, - const std::vector& scalars, +element element::batch_mul(const std::vector& _points, + const std::vector& _scalars, const size_t max_num_bits) { + const auto [points, scalars] = handle_points_at_infinity(_points, _scalars); + if constexpr (IsSimulator) { // TODO(https://github.com/AztecProtocol/barretenberg/issues/663) auto context = points[0].get_context(); @@ -639,13 +776,12 @@ element element::batch_mul(const std::vector && std::same_as) { return goblin_batch_mul(points, scalars); } else { - const size_t num_points = points.size(); ASSERT(scalars.size() == num_points); batch_lookup_table point_table(points); const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; - std::vector>> naf_entries; + std::vector> naf_entries; for (size_t i = 0; i < num_points; ++i) { naf_entries.emplace_back(compute_naf(scalars[i], max_num_bits)); } @@ -660,7 +796,7 @@ element element::batch_mul(const std::vector> nafs(num_points); + std::vector nafs(num_points); std::vector to_add; const size_t inner_num_rounds = (i != num_iterations - 1) ? num_rounds_per_iteration : num_rounds_per_final_iteration; @@ -724,14 +860,14 @@ element element::operator*(const Fr& scalar) const } else { constexpr uint64_t num_rounds = Fr::modulus.get_msb() + 1; - std::vector> naf_entries = compute_naf(scalar); + std::vector naf_entries = compute_naf(scalar); const auto offset_generators = compute_offset_generators(num_rounds); element accumulator = *this + offset_generators.first; for (size_t i = 1; i < num_rounds; ++i) { - bool_t predicate = naf_entries[i]; + bool_ct predicate = naf_entries[i]; bigfield y_test = y.conditional_negate(predicate); element to_add(x, y_test); accumulator = accumulator.montgomery_ladder(to_add); diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp index c82ad8daa..ad1663c46 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp @@ -1,5 +1,6 @@ #pragma once #include "barretenberg/ecc/curves/secp256k1/secp256k1.hpp" +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { @@ -486,17 +487,17 @@ std::vector> element::compute_naf(const Fr& scalar, cons uint256_t scalar_multiplier = scalar_multiplier_512.lo; const size_t num_rounds = (max_num_bits == 0) ? Fr::modulus.get_msb() + 1 : max_num_bits; - std::vector> naf_entries(num_rounds + 1); + std::vector naf_entries(num_rounds + 1); // if boolean is false => do NOT flip y // if boolean is true => DO flip y // first entry is skew. i.e. do we subtract one from the final result or not if (scalar_multiplier.get_bit(0) == false) { // add skew - naf_entries[num_rounds] = bool_t(witness_t(ctx, true)); + naf_entries[num_rounds] = bool_ct(witness_t(ctx, true)); scalar_multiplier += uint256_t(1); } else { - naf_entries[num_rounds] = bool_t(witness_t(ctx, false)); + naf_entries[num_rounds] = bool_ct(witness_t(ctx, false)); } for (size_t i = 0; i < num_rounds - 1; ++i) { bool next_entry = scalar_multiplier.get_bit(i + 1); @@ -504,7 +505,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons // This is a VERY hacky workaround to ensure that UltraPlonkBuilder will apply a basic // range constraint per bool, and not a full 1-bit range gate if (next_entry == false) { - bool_t bit(ctx, true); + bool_ct bit(ctx, true); bit.context = ctx; bit.witness_index = witness_t(ctx, true).witness_index; // flip sign bit.witness_bool = true; @@ -520,7 +521,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons } naf_entries[num_rounds - i - 1] = bit; } else { - bool_t bit(ctx, false); + bool_ct bit(ctx, false); bit.witness_index = witness_t(ctx, false).witness_index; // don't flip sign bit.witness_bool = false; if constexpr (IsSimulator) { @@ -537,7 +538,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons naf_entries[num_rounds - i - 1] = bit; } } - naf_entries[0] = bool_t(ctx, false); // most significant entry is always true + naf_entries[0] = bool_ct(ctx, false); // most significant entry is always true // validate correctness of NAF if constexpr (!Fr::is_composite) { @@ -554,7 +555,7 @@ std::vector> element::compute_naf(const Fr& scalar, cons Fr accumulator_result = Fr::accumulate(accumulators); scalar.assert_equal(accumulator_result); } else { - const auto reconstruct_half_naf = [](bool_t* nafs, const size_t half_round_length) { + const auto reconstruct_half_naf = [](bool_ct* nafs, const size_t half_round_length) { // Q: need constraint to start from zero? field_t negative_accumulator(0); field_t positive_accumulator(0); diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp index 6f898f6a2..15ff2a9af 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp @@ -5,6 +5,7 @@ * TODO: we should try to genericize this, but this method is super fiddly and we need it to be efficient! * **/ +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" namespace bb::stdlib { template @@ -119,14 +120,14 @@ element element::secp256k1_ecdsa_mul(const element& const element& base_point, const field_t& positive_skew, const field_t& negative_skew) { - const bool_t positive_skew_bool(positive_skew); - const bool_t negative_skew_bool(negative_skew); + const bool_ct positive_skew_bool(positive_skew); + const bool_ct negative_skew_bool(negative_skew); auto to_add = base_point; to_add.y = to_add.y.conditional_negate(negative_skew_bool); element result = accumulator + to_add; // when computing the wNAF we have already validated that positive_skew and negative_skew cannot both be true - bool_t skew_combined = positive_skew_bool ^ negative_skew_bool; + bool_ct skew_combined = positive_skew_bool ^ negative_skew_bool; result.x = accumulator.x.conditional_select(result.x, skew_combined); result.y = accumulator.y.conditional_select(result.y, skew_combined); return result; diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp index 78cc53e03..14effea1d 100644 --- a/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_tables.hpp @@ -1,4 +1,6 @@ #pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" +#include "barretenberg/stdlib/primitives/memory/twin_rom_table.hpp" #include "barretenberg/stdlib_circuit_builders/plookup_tables/types.hpp" namespace bb::stdlib { @@ -180,27 +182,27 @@ template element::lookup_table_plookup::lookup_table_plookup(const std::array& inputs) { if constexpr (length == 2) { - auto [A0, A1] = inputs[1].add_sub(inputs[0]); + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); element_table[0] = A0; element_table[1] = A1; } else if constexpr (length == 3) { - auto [R0, R1] = inputs[1].add_sub(inputs[0]); // B ± A + auto [R0, R1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A - auto [T0, T1] = inputs[2].add_sub(R0); // C ± (B + A) - auto [T2, T3] = inputs[2].add_sub(R1); // C ± (B - A) + auto [T0, T1] = inputs[2].checked_unconditional_add_sub(R0); // C ± (B + A) + auto [T2, T3] = inputs[2].checked_unconditional_add_sub(R1); // C ± (B - A) element_table[0] = T0; element_table[1] = T2; element_table[2] = T3; element_table[3] = T1; } else if constexpr (length == 4) { - auto [T0, T1] = inputs[1].add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C + auto [T0, T1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C - auto [F0, F3] = T2.add_sub(T0); // (D + C) ± (B + A) - auto [F1, F2] = T2.add_sub(T1); // (D + C) ± (B - A) - auto [F4, F7] = T3.add_sub(T0); // (D - C) ± (B + A) - auto [F5, F6] = T3.add_sub(T1); // (D - C) ± (B - A) + auto [F0, F3] = T2.checked_unconditional_add_sub(T0); // (D + C) ± (B + A) + auto [F1, F2] = T2.checked_unconditional_add_sub(T1); // (D + C) ± (B - A) + auto [F4, F7] = T3.checked_unconditional_add_sub(T0); // (D - C) ± (B + A) + auto [F5, F6] = T3.checked_unconditional_add_sub(T1); // (D - C) ± (B - A) element_table[0] = F0; element_table[1] = F1; @@ -211,20 +213,20 @@ element::lookup_table_plookup::lookup_table_plookup(con element_table[6] = F6; element_table[7] = F7; } else if constexpr (length == 5) { - auto [A0, A1] = inputs[1].add_sub(inputs[0]); // B ± A - auto [T2, T3] = inputs[3].add_sub(inputs[2]); // D ± C + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); // B ± A + auto [T2, T3] = inputs[3].checked_unconditional_add_sub(inputs[2]); // D ± C - auto [E0, E3] = inputs[4].add_sub(T2); // E ± (D + C) - auto [E1, E2] = inputs[4].add_sub(T3); // E ± (D - C) + auto [E0, E3] = inputs[4].checked_unconditional_add_sub(T2); // E ± (D + C) + auto [E1, E2] = inputs[4].checked_unconditional_add_sub(T3); // E ± (D - C) - auto [F0, F3] = E0.add_sub(A0); - auto [F1, F2] = E0.add_sub(A1); - auto [F4, F7] = E1.add_sub(A0); - auto [F5, F6] = E1.add_sub(A1); - auto [F8, F11] = E2.add_sub(A0); - auto [F9, F10] = E2.add_sub(A1); - auto [F12, F15] = E3.add_sub(A0); - auto [F13, F14] = E3.add_sub(A1); + auto [F0, F3] = E0.checked_unconditional_add_sub(A0); + auto [F1, F2] = E0.checked_unconditional_add_sub(A1); + auto [F4, F7] = E1.checked_unconditional_add_sub(A0); + auto [F5, F6] = E1.checked_unconditional_add_sub(A1); + auto [F8, F11] = E2.checked_unconditional_add_sub(A0); + auto [F9, F10] = E2.checked_unconditional_add_sub(A1); + auto [F12, F15] = E3.checked_unconditional_add_sub(A0); + auto [F13, F14] = E3.checked_unconditional_add_sub(A1); element_table[0] = F0; element_table[1] = F1; @@ -245,33 +247,33 @@ element::lookup_table_plookup::lookup_table_plookup(con } else if constexpr (length == 6) { // 44 adds! Only use this if it saves us adding another table to a multi-scalar-multiplication - auto [A0, A1] = inputs[1].add_sub(inputs[0]); - auto [E0, E1] = inputs[4].add_sub(inputs[3]); - auto [C0, C3] = inputs[2].add_sub(A0); - auto [C1, C2] = inputs[2].add_sub(A1); + auto [A0, A1] = inputs[1].checked_unconditional_add_sub(inputs[0]); + auto [E0, E1] = inputs[4].checked_unconditional_add_sub(inputs[3]); + auto [C0, C3] = inputs[2].checked_unconditional_add_sub(A0); + auto [C1, C2] = inputs[2].checked_unconditional_add_sub(A1); - auto [F0, F3] = inputs[5].add_sub(E0); - auto [F1, F2] = inputs[5].add_sub(E1); + auto [F0, F3] = inputs[5].checked_unconditional_add_sub(E0); + auto [F1, F2] = inputs[5].checked_unconditional_add_sub(E1); - auto [R0, R7] = F0.add_sub(C0); - auto [R1, R6] = F0.add_sub(C1); - auto [R2, R5] = F0.add_sub(C2); - auto [R3, R4] = F0.add_sub(C3); + auto [R0, R7] = F0.checked_unconditional_add_sub(C0); + auto [R1, R6] = F0.checked_unconditional_add_sub(C1); + auto [R2, R5] = F0.checked_unconditional_add_sub(C2); + auto [R3, R4] = F0.checked_unconditional_add_sub(C3); - auto [S0, S7] = F1.add_sub(C0); - auto [S1, S6] = F1.add_sub(C1); - auto [S2, S5] = F1.add_sub(C2); - auto [S3, S4] = F1.add_sub(C3); + auto [S0, S7] = F1.checked_unconditional_add_sub(C0); + auto [S1, S6] = F1.checked_unconditional_add_sub(C1); + auto [S2, S5] = F1.checked_unconditional_add_sub(C2); + auto [S3, S4] = F1.checked_unconditional_add_sub(C3); - auto [U0, U7] = F2.add_sub(C0); - auto [U1, U6] = F2.add_sub(C1); - auto [U2, U5] = F2.add_sub(C2); - auto [U3, U4] = F2.add_sub(C3); + auto [U0, U7] = F2.checked_unconditional_add_sub(C0); + auto [U1, U6] = F2.checked_unconditional_add_sub(C1); + auto [U2, U5] = F2.checked_unconditional_add_sub(C2); + auto [U3, U4] = F2.checked_unconditional_add_sub(C3); - auto [W0, W7] = F3.add_sub(C0); - auto [W1, W6] = F3.add_sub(C1); - auto [W2, W5] = F3.add_sub(C2); - auto [W3, W4] = F3.add_sub(C3); + auto [W0, W7] = F3.checked_unconditional_add_sub(C0); + auto [W1, W6] = F3.checked_unconditional_add_sub(C1); + auto [W2, W5] = F3.checked_unconditional_add_sub(C2); + auto [W3, W4] = F3.checked_unconditional_add_sub(C3); element_table[0] = R0; element_table[1] = R1; @@ -408,7 +410,7 @@ element::lookup_table_plookup::lookup_table_plookup(con template template element element::lookup_table_plookup::get( - const std::array, length>& bits) const + const std::array& bits) const { std::vector> accumulators; for (size_t i = 0; i < length; ++i) { @@ -558,20 +560,20 @@ element::lookup_table_base::lookup_table_base(const std::a template template element element::lookup_table_base::get( - const std::array, length>& bits) const + const std::array& bits) const { static_assert(length <= 4 && length >= 2); if constexpr (length == 2) { - bool_t table_selector = bits[0] ^ bits[1]; - bool_t sign_selector = bits[1]; + bool_ct table_selector = bits[0] ^ bits[1]; + bool_ct sign_selector = bits[1]; Fq to_add_x = twin0.x.conditional_select(twin1.x, table_selector); Fq to_add_y = twin0.y.conditional_select(twin1.y, table_selector); element to_add(to_add_x, to_add_y.conditional_negate(sign_selector)); return to_add; } else if constexpr (length == 3) { - bool_t t0 = bits[2] ^ bits[0]; - bool_t t1 = bits[2] ^ bits[1]; + bool_ct t0 = bits[2] ^ bits[0]; + bool_ct t1 = bits[2] ^ bits[1]; field_t x_b0 = field_t::select_from_two_bit_table(x_b0_table, t1, t0); field_t x_b1 = field_t::select_from_two_bit_table(x_b1_table, t1, t0); @@ -604,9 +606,9 @@ element element::lookup_table_base::get( return to_add; } else if constexpr (length == 4) { - bool_t t0 = bits[3] ^ bits[0]; - bool_t t1 = bits[3] ^ bits[1]; - bool_t t2 = bits[3] ^ bits[2]; + bool_ct t0 = bits[3] ^ bits[0]; + bool_ct t1 = bits[3] ^ bits[1]; + bool_ct t2 = bits[3] ^ bits[2]; field_t x_b0 = field_t::select_from_three_bit_table(x_b0_table, t2, t1, t0); field_t x_b1 = field_t::select_from_three_bit_table(x_b1_table, t2, t1, t0); diff --git a/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp b/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp new file mode 100644 index 000000000..b211e08d6 --- /dev/null +++ b/cpp/src/barretenberg/stdlib/primitives/biggroup/handle_points_at_infinity.hpp @@ -0,0 +1,42 @@ +#pragma once +#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp" + +namespace bb::stdlib { + +/** + * @brief Replace all pairs (∞, scalar) by the pair (one, 0) where one is a fixed generator of the curve + * @details This is a step in enabling our our multiscalar multiplication algorithms to hande points at infinity. + */ +template +std::pair>, std::vector> element::handle_points_at_infinity( + const std::vector& _points, const std::vector& _scalars) +{ + auto builder = _points[0].get_context(); + std::vector points; + std::vector scalars; + element one = element::one(builder); + + for (auto [_point, _scalar] : zip_view(_points, _scalars)) { + bool_ct is_point_at_infinity = _point.is_point_at_infinity(); + if (is_point_at_infinity.get_value() && static_cast(is_point_at_infinity.is_constant())) { + // if point is at infinity and a circuit constant we can just skip. + continue; + } + if (_scalar.get_value() == 0 && _scalar.is_constant()) { + // if scalar multiplier is 0 and also a constant, we can skip + continue; + } + Fq updated_x = Fq::conditional_assign(is_point_at_infinity, one.x, _point.x); + Fq updated_y = Fq::conditional_assign(is_point_at_infinity, one.y, _point.y); + element point(updated_x, updated_y); + Fr scalar = Fr::conditional_assign(is_point_at_infinity, 0, _scalar); + + points.push_back(point); + scalars.push_back(scalar); + // TODO(https://github.com/AztecProtocol/barretenberg/issues/1002): if both point and scalar are constant, don't + // bother adding constraints + } + + return { points, scalars }; +} +} // namespace bb::stdlib diff --git a/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp b/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp index a6593e4f8..5b7a5106f 100644 --- a/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp +++ b/cpp/src/barretenberg/stdlib/primitives/curves/secp256r1.hpp @@ -11,9 +11,9 @@ namespace bb::stdlib { template struct secp256r1 { static constexpr bb::CurveType type = bb::CurveType::SECP256R1; - typedef ::secp256r1::fq fq; - typedef ::secp256r1::fr fr; - typedef ::secp256r1::g1 g1; + typedef bb::secp256r1::fq fq; + typedef bb::secp256r1::fr fr; + typedef bb::secp256r1::g1 g1; typedef CircuitType Builder; typedef witness_t witness_ct; @@ -23,8 +23,8 @@ template struct secp256r1 { typedef bool_t bool_ct; typedef stdlib::uint32 uint32_ct; - typedef bigfield fq_ct; - typedef bigfield bigfr_ct; + typedef bigfield fq_ct; + typedef bigfield bigfr_ct; typedef element g1_ct; typedef element g1_bigfr_ct; }; diff --git a/cpp/src/barretenberg/stdlib_circuit_builders/mega_circuit_builder.cpp b/cpp/src/barretenberg/stdlib_circuit_builders/mega_circuit_builder.cpp index a59353f05..1c1397898 100644 --- a/cpp/src/barretenberg/stdlib_circuit_builders/mega_circuit_builder.cpp +++ b/cpp/src/barretenberg/stdlib_circuit_builders/mega_circuit_builder.cpp @@ -144,6 +144,7 @@ template ecc_op_tuple MegaCircuitBuilder_::queue_ecc_eq() // Add corresponding gates for the operation ecc_op_tuple op_tuple = populate_ecc_op_wires(ultra_op); + op_tuple.return_is_infinity = ultra_op.return_is_infinity; return op_tuple; } diff --git a/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp b/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp index c3f04728c..8de93d520 100644 --- a/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp +++ b/cpp/src/barretenberg/stdlib_circuit_builders/op_queue/ecc_op_queue.hpp @@ -17,6 +17,7 @@ struct UltraOp { Fr y_hi; Fr z_1; Fr z_2; + bool return_is_infinity; }; /** @@ -460,6 +461,7 @@ class ECCOpQueue { const size_t CHUNK_SIZE = 2 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS; auto x_256 = uint256_t(point.x); auto y_256 = uint256_t(point.y); + ultra_op.return_is_infinity = point.is_point_at_infinity(); ultra_op.x_lo = Fr(x_256.slice(0, CHUNK_SIZE)); ultra_op.x_hi = Fr(x_256.slice(CHUNK_SIZE, CHUNK_SIZE * 2)); ultra_op.y_lo = Fr(y_256.slice(0, CHUNK_SIZE)); diff --git a/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp b/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp index a13e69df2..c58b42e05 100644 --- a/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp +++ b/cpp/src/barretenberg/vm/tests/avm_inter_table.test.cpp @@ -9,6 +9,7 @@ using namespace bb; namespace tests_avm { +using namespace bb; using namespace bb::avm_trace; class AvmInterTableTests : public ::testing::Test { diff --git a/cpp/src/barretenberg/vm/tests/helpers.test.cpp b/cpp/src/barretenberg/vm/tests/helpers.test.cpp index 6b5dadbc0..b1fb19d38 100644 --- a/cpp/src/barretenberg/vm/tests/helpers.test.cpp +++ b/cpp/src/barretenberg/vm/tests/helpers.test.cpp @@ -3,6 +3,7 @@ #include "barretenberg/vm/avm_trace/constants.hpp" #include "barretenberg/vm/generated/avm_flavor.hpp" +using namespace bb; namespace tests_avm { using namespace bb;