Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: updating field conversion code without pointer hack #4537

Merged
merged 7 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class Bn254FqParams {
// used in msgpack schema serialization
static constexpr char schema_name[] = "fq";
static constexpr bool has_high_2adicity = false;

// The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars
static constexpr size_t NUM_BN254_SCALARS = 2;
};

using fq = field<Bn254FqParams>;
Expand Down
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class Bn254FrParams {
// used in msgpack schema serialization
static constexpr char schema_name[] = "fr";
static constexpr bool has_high_2adicity = true;

// This is a BN254 scalar, so it represents one BN254 scalar
static constexpr size_t NUM_BN254_SCALARS = 1;
};

using fr = field<Bn254FrParams>;
Expand Down
59 changes: 2 additions & 57 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@ namespace bb::field_conversion {
static constexpr uint64_t NUM_LIMB_BITS = plonk::NUM_LIMB_BITS_IN_FIELD_SIMULATION;
static constexpr uint64_t TOTAL_BITS = 254;

bb::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bb::fr* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return fr_vec[0];
}

bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return fr_vec[0] != 0;
}

/**
* @brief Converts 2 bb::fr elements to grumpkin::fr
* @details First, this function must take in 2 bb::fr elements because the grumpkin::fr field has a larger modulus than
Expand All @@ -32,7 +20,7 @@ bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/)
* @param high_bits_in
* @return grumpkin::fr
*/
grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr* /*unused*/)
grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
// Combines the two elements into one uint256_t, and then convert that to a grumpkin::fr
ASSERT(uint256_t(fr_vec[0]) < (uint256_t(1) << (NUM_LIMB_BITS * 2))); // lower 136 bits
Expand All @@ -42,25 +30,6 @@ grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr
return result;
}

curve::BN254::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::BN254::AffineElement* /*unused*/)
{
curve::BN254::AffineElement val;
val.x = convert_from_bn254_frs<grumpkin::fr>(fr_vec.subspan(0, 2));
val.y = convert_from_bn254_frs<grumpkin::fr>(fr_vec.subspan(2, 2));
return val;
}

curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::Grumpkin::AffineElement* /*unused*/)
{
ASSERT(fr_vec.size() == 2);
curve::Grumpkin::AffineElement val;
val.x = fr_vec[0];
val.y = fr_vec[1];
return val;
}

/**
* @brief Converts grumpkin::fr to 2 bb::fr elements
* @details First, this function must return 2 bb::fr elements because the grumpkin::fr field has a larger modulus than
Expand All @@ -74,7 +43,7 @@ curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr
* @param input
* @return std::array<bb::fr, 2>
*/
std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val)
std::vector<bb::fr> convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val)
{
// Goal is to slice up the 64 bit limbs of grumpkin::fr/uint256_t to mirror the 68 bit limbs of bigfield
// We accomplish this by dividing the grumpkin::fr's value into two 68*2=136 bit pieces.
Expand All @@ -89,30 +58,6 @@ std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val)
return result;
}

std::vector<bb::fr> convert_to_bn254_frs(const bb::fr& val)
{
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const curve::BN254::AffineElement& val)
{
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val)
{
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
}

grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f)
{
const uint64_t NUM_BITS_IN_TWO_LIMBS = 2 * NUM_LIMB_BITS; // the number of bits in 2 bigfield limbs which is 136
Expand Down
225 changes: 64 additions & 161 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/proof_system/types/circuit_type.hpp"

namespace bb::field_conversion {

Expand All @@ -15,48 +16,22 @@ namespace bb::field_conversion {
* @tparam T
* @return constexpr size_t
*/
template <typename T> constexpr size_t calc_num_bn254_frs();

constexpr size_t calc_num_bn254_frs(bb::fr* /*unused*/)
{
return 1;
}

constexpr size_t calc_num_bn254_frs(grumpkin::fr* /*unused*/)
{
return 2;
}

template <std::integral T> constexpr size_t calc_num_bn254_frs(T* /*unused*/)
{
return 1; // meant for integral types that are less than 254 bits
}

constexpr size_t calc_num_bn254_frs(curve::BN254::AffineElement* /*unused*/)
{
return 2 * calc_num_bn254_frs<curve::BN254::BaseField>();
}

constexpr size_t calc_num_bn254_frs(curve::Grumpkin::AffineElement* /*unused*/)
{
return 2 * calc_num_bn254_frs<curve::Grumpkin::BaseField>();
}

template <typename T, std::size_t N> constexpr size_t calc_num_bn254_frs(std::array<T, N>* /*unused*/)
{
return N * calc_num_bn254_frs<T>();
}

template <typename T, std::size_t N> constexpr size_t calc_num_bn254_frs(bb::Univariate<T, N>* /*unused*/)
{
return N * calc_num_bn254_frs<T>();
}

template <typename T> constexpr size_t calc_num_bn254_frs()
{
return calc_num_bn254_frs(static_cast<T*>(nullptr));
if constexpr (IsAnyOf<T, uint32_t, bool>) {
return 1;
} else if constexpr (IsAnyOf<T, bb::fr, grumpkin::fr>) {
return T::Params::NUM_BN254_SCALARS;
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
return 2 * calc_num_bn254_frs<typename T::Fq>();
} else {
// Array or Univariate
return calc_num_bn254_frs<typename T::value_type>() * (std::tuple_size<T>::value);
}
}

grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span<const bb::fr> fr_vec);

/**
* @brief Conversions from vector of bb::fr elements to transcript types.
* @details We want to support the following types: bool, size_t, uint32_t, uint64_t, bb::fr, grumpkin::fr,
Expand All @@ -68,75 +43,40 @@ template <typename T> constexpr size_t calc_num_bn254_frs()
* @param fr_vec
* @return T
*/
template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec);

bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/);

template <std::integral T> inline T convert_from_bn254_frs(std::span<const bb::fr> fr_vec, T* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return static_cast<T>(fr_vec[0]);
}

bb::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bb::fr* /*unused*/);

grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr* /*unused*/);

curve::BN254::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::BN254::AffineElement* /*unused*/);

curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::Grumpkin::AffineElement* /*unused*/);

template <size_t N>
inline std::array<bb::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec, std::array<bb::fr, N>* /*unused*/)
{
std::array<bb::fr, N> val;
for (size_t i = 0; i < N; ++i) {
val[i] = fr_vec[i];
}
return val;
}

template <size_t N>
inline std::array<grumpkin::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
std::array<grumpkin::fr, N>* /*unused*/)
{
std::array<grumpkin::fr, N> val;
for (size_t i = 0; i < N; ++i) {
std::vector<bb::fr> fr_vec_tmp{ fr_vec[2 * i],
fr_vec[2 * i + 1] }; // each pair of consecutive elements is a grumpkin::fr
val[i] = convert_from_bn254_frs<grumpkin::fr>(fr_vec_tmp);
}
return val;
}

template <size_t N>
inline Univariate<bb::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec, Univariate<bb::fr, N>* /*unused*/)
{
Univariate<bb::fr, N> val;
for (size_t i = 0; i < N; ++i) {
val.evaluations[i] = fr_vec[i];
}
return val;
}

template <size_t N>
inline Univariate<grumpkin::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
Univariate<grumpkin::fr, N>* /*unused*/)
template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
Univariate<grumpkin::fr, N> val;
for (size_t i = 0; i < N; ++i) {
std::vector<bb::fr> fr_vec_tmp{ fr_vec[2 * i], fr_vec[2 * i + 1] };
val.evaluations[i] = convert_from_bn254_frs<grumpkin::fr>(fr_vec_tmp);
if constexpr (IsAnyOf<T, bool>) {
ASSERT(fr_vec.size() == 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a little funny. Can we extract the bottom bit, ASSERT it's 0 or 1, and then return?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed. calling the bool() operator will assert that its 0 or 1.

return fr_vec[0] != 0;
} else if constexpr (IsAnyOf<T, uint32_t, bb::fr>) {
ASSERT(fr_vec.size() == 1);
return static_cast<T>(fr_vec[0]);
} else if constexpr (IsAnyOf<T, grumpkin::fr>) {
ASSERT(fr_vec.size() == 2);
return convert_grumpkin_fr_from_bn254_frs(fr_vec);
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
using BaseField = typename T::Fq;
constexpr size_t BaseFieldScalarSize = calc_num_bn254_frs<BaseField>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've historically used THIS_CASE for constexpr values, and ThisCase for classes. Could you switch here and elsewhere?

Copy link
Contributor Author

@lucasxia01 lucasxia01 Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done for BaseFieldScalarSize, not sure if I missed anything.

ASSERT(fr_vec.size() == 2 * BaseFieldScalarSize);
T val;
val.x = convert_from_bn254_frs<BaseField>(fr_vec.subspan(0, BaseFieldScalarSize));
val.y = convert_from_bn254_frs<BaseField>(fr_vec.subspan(BaseFieldScalarSize, BaseFieldScalarSize));
return val;
} else {
// Array or Univariate
T val;
constexpr size_t FieldScalarSize = calc_num_bn254_frs<typename T::value_type>();
ASSERT(fr_vec.size() == FieldScalarSize * std::tuple_size<T>::value);
size_t i = 0;
for (auto& x : val) {
x = convert_from_bn254_frs<typename T::value_type>(fr_vec.subspan(FieldScalarSize * i, FieldScalarSize));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a little funny, shouldn't the second subspan argument have an i+1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subspan takes in an offset and a count, not a start and end.

++i;
}
return val;
}
return val;
}

template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
return convert_from_bn254_frs(fr_vec, static_cast<T*>(nullptr));
}
std::vector<bb::fr> convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val);

/**
* @brief Conversion from transcript values to bb::frs
Expand All @@ -147,65 +87,28 @@ template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
* @param val
* @return std::vector<bb::fr>
*/
template <std::integral T> std::vector<bb::fr> inline convert_to_bn254_frs(const T& val)
{
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val);

std::vector<bb::fr> convert_to_bn254_frs(const bb::fr& val);

std::vector<bb::fr> convert_to_bn254_frs(const curve::BN254::AffineElement& val);

std::vector<bb::fr> convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val);

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const std::array<bb::fr, N>& val)
{
std::vector<bb::fr> fr_vec(val.begin(), val.end());
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const std::array<grumpkin::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const bb::Univariate<bb::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const bb::Univariate<grumpkin::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <typename AllValues> std::vector<bb::fr> inline convert_to_bn254_frs(const AllValues& val)
{
auto data = val.get_all();
std::vector<bb::fr> fr_vec;
for (auto& item : data) {
auto tmp_vec = convert_to_bn254_frs(item);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
template <typename T> std::vector<bb::fr> convert_to_bn254_frs(const T& val)
{
if constexpr (IsAnyOf<T, bool, uint32_t, bb::fr>) {
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
} else if constexpr (IsAnyOf<T, grumpkin::fr>) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still am oldschool so constexpr is a weird way to do this for me, but if the team likes it :)

return convert_grumpkin_fr_to_bn254_frs(val);
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
} else {
// Array or Univariate
std::vector<bb::fr> fr_vec;
for (auto& x : val) {
auto tmp_vec = convert_to_bn254_frs(x);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}
return fr_vec;
}

grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
namespace bb::group_elements {
template <typename T>
concept SupportsHashToCurve = T::can_hash_to_curve;
template <typename Fq, typename Fr, typename Params> class alignas(64) affine_element {
template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_element {
public:
using Fq = Fq_;
using Fr = Fr_;

using in_buf = const uint8_t*;
using vec_in_buf = const uint8_t*;
using out_buf = uint8_t*;
Expand Down
Loading
Loading