Skip to content

Commit

Permalink
feat: added barrett_reduction implementation into uintx (#6768)
Browse files Browse the repository at this point in the history
This PR adds a `barrett_reduction` method into `unitx`, a fast division
algorithm when the divisor is known ahead of time such that precomputed
factors can be determined.

`barrett_reduction` is used to speed up `divmod` for some important
hardcoded moduli. Or particular relevance is the prime field associated
with BN254 curve arithmetic, as expensive 1024-bit `divmod` operations
are performed when computing witnesses within `stdlib::bitfield` -
commonly used to perform non-native BN254 curve arithmetic.

Speeds up biggroup batch_mul 4x

---------

Co-authored-by: Rumata888 <isennovskiy@gmail.com>
  • Loading branch information
2 people authored and AztecBot committed Jul 13, 2024
1 parent 89d723d commit ad5e73b
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 1 deletion.
1 change: 1 addition & 0 deletions cpp/src/barretenberg/benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ add_subdirectory(indexed_tree_bench)
add_subdirectory(append_only_tree_bench)
add_subdirectory(ultra_bench)
add_subdirectory(stdlib_hash)
add_subdirectory(circuit_construction_bench)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
barretenberg_module(circuit_construction_bench stdlib_primitives)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

#include <benchmark/benchmark.h>

#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp"
#include "barretenberg/stdlib/primitives/curves/bn254.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"

using namespace benchmark;
using namespace bb;

namespace {

auto& engine = numeric::get_debug_randomness();
void biggroup_construction_bench(State& state)
{
using Curve = stdlib::bn254<UltraCircuitBuilder>;
using affine_element = Curve::AffineElementNative;
using element_ct = Curve::Element;
using scalar_ct = Curve::ScalarField;
for (auto _ : state) {
state.PauseTiming();

UltraCircuitBuilder builder;
size_t num_points = static_cast<size_t>(state.range(0));
std::vector<affine_element> points;
std::vector<fr> scalars;
for (size_t i = 0; i < num_points; ++i) {
points.push_back(affine_element(Curve::ElementNative::random_element()));
scalars.push_back(fr::random_element());
}

std::vector<element_ct> circuit_points;
std::vector<scalar_ct> circuit_scalars;
for (size_t i = 0; i < num_points; ++i) {
circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
}
state.ResumeTiming();
element_ct::batch_mul(circuit_points, circuit_scalars);
state.PauseTiming();
}
}
} // namespace
BENCHMARK(biggroup_construction_bench)->Unit(kMicrosecond)->DenseRange(2, 20);

BENCHMARK_MAIN();
2 changes: 2 additions & 0 deletions cpp/src/barretenberg/numeric/uintx/uintx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ template <class base_uint> class uintx {
base_uint lo;
base_uint hi;

template <base_uint modulus> constexpr std::pair<uintx, uintx> barrett_reduction() const;
constexpr std::pair<uintx, uintx> divmod(const uintx& b) const;
constexpr std::pair<uintx, uintx> divmod_base(const uintx& b) const;
};

template <typename B, typename Params> inline void read(B& it, uintx<Params>& value)
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/barretenberg/numeric/uintx/uintx.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,39 @@ namespace {
auto& engine = numeric::get_debug_randomness();
} // namespace

TEST(uintx, BarrettReduction512)
{
uint512_t x = engine.get_random_uint512();

static constexpr uint64_t modulus_0 = 0x3C208C16D87CFD47UL;
static constexpr uint64_t modulus_1 = 0x97816a916871ca8dUL;
static constexpr uint64_t modulus_2 = 0xb85045b68181585dUL;
static constexpr uint64_t modulus_3 = 0x30644e72e131a029UL;
constexpr uint256_t modulus(modulus_0, modulus_1, modulus_2, modulus_3);

const auto [quotient_result, remainder_result] = x.barrett_reduction<modulus>();
const auto [quotient_expected, remainder_expected] = x.divmod_base(uint512_t(modulus));
EXPECT_EQ(quotient_result, quotient_expected);
EXPECT_EQ(remainder_result, remainder_expected);
}

TEST(uintx, BarrettReduction1024)
{
uint1024_t x = engine.get_random_uint1024();

static constexpr uint64_t modulus_0 = 0x3C208C16D87CFD47UL;
static constexpr uint64_t modulus_1 = 0x97816a916871ca8dUL;
static constexpr uint64_t modulus_2 = 0xb85045b68181585dUL;
static constexpr uint64_t modulus_3 = 0x30644e72e131a029UL;
constexpr uint256_t modulus_partial(modulus_0, modulus_1, modulus_2, modulus_3);
constexpr uint512_t modulus = uint512_t(modulus_partial) * uint512_t(modulus_partial);

const auto [quotient_result, remainder_result] = x.barrett_reduction<modulus>();
const auto [quotient_expected, remainder_expected] = x.divmod_base(uint1024_t(modulus));
EXPECT_EQ(quotient_result, quotient_expected);
EXPECT_EQ(remainder_result, remainder_expected);
}

TEST(uintx, GetBit)
{
constexpr uint256_t lo{ 0b0110011001110010011001100111001001100110011100100110011001110011,
Expand Down
85 changes: 84 additions & 1 deletion cpp/src/barretenberg/numeric/uintx/uintx_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace bb::numeric {
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod_base(const uintx& b) const
{
ASSERT(b != 0);
if (*this == 0) {
Expand Down Expand Up @@ -336,4 +336,87 @@ template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator
}
return result;
}

template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
{
constexpr uint256_t BN254FQMODULUS256 =
uint256_t(0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL, 0xb85045b68181585dUL, 0x30644e72e131a029UL);
constexpr uint256_t SECP256K1FQMODULUS256 =
uint256_t(0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL);
constexpr uint256_t SECP256R1FQMODULUS256 =
uint256_t(0xFFFFFFFFFFFFFFFFULL, 0x00000000FFFFFFFFULL, 0x0000000000000000ULL, 0xFFFFFFFF00000001ULL);

if (b == uintx(BN254FQMODULUS256)) {
return (*this).template barrett_reduction<BN254FQMODULUS256>();
}
if (b == uintx(SECP256K1FQMODULUS256)) {
return (*this).template barrett_reduction<SECP256K1FQMODULUS256>();
}
if (b == uintx(SECP256R1FQMODULUS256)) {
return (*this).template barrett_reduction<SECP256R1FQMODULUS256>();
}

return divmod_base(b);
}

/**
* @brief Compute fast division via a barrett reduction
* Evaluates x = qm + r where m = modulus. returns q, r
* @details This implementation is less efficient due to making no assumptions about the value of *self.
* When using this method to perform modular reductions e.g. (*self) mod m, if (*self) < m^2 a lot of the
* `uintx` operations in this method could be replaced with `base_uint` operations
*
* @tparam base_uint
* @tparam modulus
* @return constexpr std::pair<uintx<base_uint>, uintx<base_uint>>
*/
template <class base_uint>
template <base_uint modulus>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::barrett_reduction() const
{
// N.B. k could be modulus.get_msb() + 1 if we have strong bounds on the max value of (*self)
// (a smaller k would allow us to fit `redc_parameter` into `base_uint` and not `uintx`)
constexpr size_t k = base_uint::length() - 1;
// N.B. computation of redc_parameter requires division operation - if this cannot be precomputed (or amortized over
// multiple reductions over the same modulus), barrett_reduction is much slower than divmod
constexpr uintx redc_parameter = ((uintx(1) << (k * 2)).divmod_base(uintx(modulus))).first;

const auto x = *this;

// compute x * redc_parameter
const auto mul_result = x.mul_extended(redc_parameter);
constexpr size_t shift = 2 * k;

// compute (x * redc_parameter) >> 2k
// This is equivalent to (x * (2^{2k} / modulus) / 2^{2k})
// which approximates to x / modulus
const uintx downshifted_hi_bits = mul_result.second & ((uintx(1) << shift) - 1);
const uintx mul_hi_underflow = uintx(downshifted_hi_bits) << (length() - shift);
uintx quotient = (mul_result.first >> shift) | mul_hi_underflow;

// compute remainder by determining value of x - quotient * modulus
uintx qm_lo(0);
{
const auto lolo = quotient.lo.mul_extended(modulus);
const auto lohi = quotient.hi.mul_extended(modulus);
base_uint t0 = lolo.first;
base_uint t1 = lolo.second;
t1 = t1 + lohi.first;
qm_lo = uintx(t0, t1);
}
uintx remainder = x - qm_lo;

// because redc_parameter is an imperfect representation of 2^{2k} / n (might be too small),
// the computed quotient may be off by up to 3 (classic algorithm should be up to 1,
// TODO(https://github.com/AztecProtocol/barretenberg/issues/1051): investigate, why)
size_t i = 0;
while (remainder >= uintx(modulus)) {
ASSERT(i < 3);
remainder = remainder - modulus;
quotient = quotient + 1;
i++;
}
return std::make_pair(quotient, remainder);
}
} // namespace bb::numeric

0 comments on commit ad5e73b

Please sign in to comment.