Skip to content

Commit

Permalink
Field creation automated through macros (#551)
Browse files Browse the repository at this point in the history
Params files for fields now only require modulus specified by the user
(also twiddle generator and/or non-residue in case either or both are
needed). Everything else gets generated by a macro.
  • Loading branch information
DmytroTym authored Jul 8, 2024
1 parent 73cd4c0 commit 2d4059c
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 1,696 deletions.
153 changes: 51 additions & 102 deletions icicle/include/fields/field.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public:

static constexpr HOST_DEVICE_INLINE Field from(uint32_t value)
{
storage<TLC> scalar;
storage<TLC> scalar{};
scalar.limbs[0] = value;
for (int i = 1; i < TLC; i++) {
scalar.limbs[i] = 0;
Expand All @@ -58,8 +58,10 @@ public:

if (logn > CONFIG::omegas_count) { THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid omega index"); }

storage_array<CONFIG::omegas_count, TLC> const omega = CONFIG::omega;
return Field{omega.storages[logn - 1]};
Field omega = Field{CONFIG::rou};
for (int i = 0; i < CONFIG::omegas_count - logn; i++)
omega = sqr(omega);
return omega;
}

static HOST_INLINE Field omega_inv(uint32_t logn)
Expand All @@ -70,8 +72,10 @@ public:
THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid omega_inv index");
}

storage_array<CONFIG::omegas_count, TLC> const omega_inv = CONFIG::omega_inv;
return Field{omega_inv.storages[logn - 1]};
Field omega = inverse(Field{CONFIG::rou});
for (int i = 0; i < CONFIG::omegas_count - logn; i++)
omega = sqr(omega);
return omega;
}

static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
Expand Down Expand Up @@ -182,32 +186,32 @@ public:
if (REDUCTION_SIZE == 0) return xs;
const ff_wide_storage modulus = get_modulus_squared<REDUCTION_SIZE>();
Wide rs = {};
return sub_limbs<true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
return sub_limbs<2 * TLC, true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
}

template <unsigned MODULUS_MULTIPLE = 1>
static constexpr HOST_DEVICE_INLINE Wide neg(const Wide& xs)
{
const ff_wide_storage modulus = get_modulus_squared<MODULUS_MULTIPLE>();
Wide rs = {};
sub_limbs<false>(modulus, xs.limbs_storage, rs.limbs_storage);
sub_limbs<2 * TLC, false>(modulus, xs.limbs_storage, rs.limbs_storage);
return rs;
}

friend HOST_DEVICE_INLINE Wide operator+(Wide xs, const Wide& ys)
{
Wide rs = {};
add_limbs<false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
add_limbs<2 * TLC, false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
return sub_modulus_squared<1>(rs);
}

friend HOST_DEVICE_INLINE Wide operator-(Wide xs, const Wide& ys)
{
Wide rs = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
uint32_t carry = sub_limbs<2 * TLC, true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
if (carry == 0) return rs;
const ff_wide_storage modulus = get_modulus_squared<1>();
add_limbs<false>(rs.limbs_storage, modulus, rs.limbs_storage);
add_limbs<2 * TLC, false>(rs.limbs_storage, modulus, rs.limbs_storage);
return rs;
}
};
Expand All @@ -228,12 +232,6 @@ public:
}
}

template <unsigned MULTIPLIER = 1>
static constexpr HOST_DEVICE_INLINE ff_wide_storage modulus_wide()
{
return CONFIG::modulus_wide;
}

// return m
static constexpr HOST_DEVICE_INLINE ff_storage get_m() { return CONFIG::m; }

Expand All @@ -253,12 +251,11 @@ public:
}
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r, size_t n = (TLC >> 1))
template <unsigned NLIMBS, bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r)
{
r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]);
for (unsigned i = 1; i < n; i++)
for (unsigned i = 1; i < NLIMBS; i++)
r[i] = SUBTRACT ? ptx::subc_cc(x[i], y[i]) : ptx::addc_cc(x[i], y[i]);
if (!CARRY_OUT) {
ptx::addc(0, 0);
Expand All @@ -267,71 +264,35 @@ public:
return SUBTRACT ? ptx::subc(0, 0) : ptx::addc(0, 0);
}

// add or subtract limbs
template <bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_limbs_device(const ff_storage& xs, const ff_storage& ys, ff_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
return add_sub_u32_device<SUBTRACT, CARRY_OUT>(x, y, r, TLC);
}

template <bool SUBTRACT, bool CARRY_OUT>
template <unsigned NLIMBS, bool SUBTRACT, bool CARRY_OUT>
static constexpr DEVICE_INLINE uint32_t
add_sub_limbs_device(const ff_wide_storage& xs, const ff_wide_storage& ys, ff_wide_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
return add_sub_u32_device<SUBTRACT, CARRY_OUT>(x, y, r, 2 * TLC);
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr HOST_INLINE uint32_t add_sub_limbs_host(const ff_storage& xs, const ff_storage& ys, ff_storage& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
uint32_t carry = 0;
host_math::carry_chain<TLC, false, CARRY_OUT> chain;
for (unsigned i = 0; i < TLC; i++)
r[i] = SUBTRACT ? chain.sub(x[i], y[i], carry) : chain.add(x[i], y[i], carry);
return CARRY_OUT ? carry : 0;
}

template <bool SUBTRACT, bool CARRY_OUT>
static constexpr HOST_INLINE uint32_t
add_sub_limbs_host(const ff_wide_storage& xs, const ff_wide_storage& ys, ff_wide_storage& rs)
add_sub_limbs_device(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
const uint32_t* x = xs.limbs;
const uint32_t* y = ys.limbs;
uint32_t* r = rs.limbs;
uint32_t carry = 0;
host_math::carry_chain<2 * TLC, false, CARRY_OUT> chain;
for (unsigned i = 0; i < 2 * TLC; i++)
r[i] = SUBTRACT ? chain.sub(x[i], y[i], carry) : chain.add(x[i], y[i], carry);
return CARRY_OUT ? carry : 0;
return add_sub_u32_device<NLIMBS, SUBTRACT, CARRY_OUT>(x, y, r);
}

template <bool CARRY_OUT, typename T>
static constexpr HOST_DEVICE_INLINE uint32_t add_limbs(const T& xs, const T& ys, T& rs)
template <unsigned NLIMBS, bool CARRY_OUT>
static constexpr HOST_DEVICE_INLINE uint32_t
add_limbs(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
#ifdef __CUDA_ARCH__
return add_sub_limbs_device<false, CARRY_OUT>(xs, ys, rs);
return add_sub_limbs_device<NLIMBS, false, CARRY_OUT>(xs, ys, rs);
#else
return add_sub_limbs_host<false, CARRY_OUT>(xs, ys, rs);
return host_math::template add_sub_limbs<NLIMBS, false, CARRY_OUT>(xs, ys, rs);
#endif
}

template <bool CARRY_OUT, typename T>
static constexpr HOST_DEVICE_INLINE uint32_t sub_limbs(const T& xs, const T& ys, T& rs)
template <unsigned NLIMBS, bool CARRY_OUT>
static constexpr HOST_DEVICE_INLINE uint32_t
sub_limbs(const storage<NLIMBS>& xs, const storage<NLIMBS>& ys, storage<NLIMBS>& rs)
{
#ifdef __CUDA_ARCH__
return add_sub_limbs_device<true, CARRY_OUT>(xs, ys, rs);
return add_sub_limbs_device<NLIMBS, true, CARRY_OUT>(xs, ys, rs);
#else
return add_sub_limbs_host<true, CARRY_OUT>(xs, ys, rs);
return host_math::template add_sub_limbs<NLIMBS, true, CARRY_OUT>(xs, ys, rs);
#endif
}

Expand Down Expand Up @@ -531,7 +492,7 @@ public:
// are necessarily NTT-friendly, `b[0]` often turns out to be \f$ 2^{32} - 1 \f$. This actually leads to
// less efficient SASS generated by nvcc, so this case needed separate handling.
if (b[0] == UINT32_MAX) {
add_sub_u32_device<true, false>(c, a, even, TLC);
add_sub_u32_device<TLC, true, false>(c, a, even);
for (i = 0; i < TLC - 1; i++)
odd[i] = a[i];
} else {
Expand Down Expand Up @@ -639,17 +600,18 @@ public:
__align__(16) uint32_t diffs[TLC];
// Differences of halves \f$ a_{hi} - a_{lo}; b_{lo} - b_{hi} \$f are written into `diffs`, signs written to
// `carry1` and `carry2`.
uint32_t carry1 = add_sub_u32_device<true, true>(&a[TLC >> 1], a, diffs);
uint32_t carry2 = add_sub_u32_device<true, true>(b, &b[TLC >> 1], &diffs[TLC >> 1]);
uint32_t carry1 = add_sub_u32_device<(TLC >> 1), true, true>(&a[TLC >> 1], a, diffs);
uint32_t carry2 = add_sub_u32_device<(TLC >> 1), true, true>(b, &b[TLC >> 1], &diffs[TLC >> 1]);
// Compute the "middle part" of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} \f$.
// This is where the assumption about unset high bit of `a` and `b` is relevant.
multiply_and_add_short_raw_device(diffs, &diffs[TLC >> 1], middle_part, r, &r[TLC]);
// Corrections that need to be performed when differences are negative.
// Again, carry doesn't need to be propagated due to unset high bits of `a` and `b`.
if (carry1) add_sub_u32_device<true, false>(&middle_part[TLC >> 1], &diffs[TLC >> 1], &middle_part[TLC >> 1]);
if (carry2) add_sub_u32_device<true, false>(&middle_part[TLC >> 1], diffs, &middle_part[TLC >> 1]);
if (carry1)
add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], &diffs[TLC >> 1], &middle_part[TLC >> 1]);
if (carry2) add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], diffs, &middle_part[TLC >> 1]);
// Now that middle part is fully correct, it can be added to the result.
add_sub_u32_device<false, true>(&r[TLC >> 1], middle_part, &r[TLC >> 1], TLC);
add_sub_u32_device<TLC, false, true>(&r[TLC >> 1], middle_part, &r[TLC >> 1]);

// Carry from adding middle part has to be propagated to the highest limb.
for (size_t i = TLC + (TLC >> 1); i < 2 * TLC; i++)
Expand All @@ -673,25 +635,12 @@ public:
}
}

static HOST_INLINE void multiply_raw_host(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
{
const uint32_t* a = as.limbs;
const uint32_t* b = bs.limbs;
uint32_t* r = rs.limbs;
for (unsigned i = 0; i < TLC; i++) {
uint32_t carry = 0;
for (unsigned j = 0; j < TLC; j++)
r[j + i] = host_math::madc_cc(a[j], b[i], r[j + i], carry);
r[TLC + i] = carry;
}
}

static HOST_DEVICE_INLINE void multiply_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
{
#ifdef __CUDA_ARCH__
return multiply_raw_device(as, bs, rs);
#else
return multiply_raw_host(as, bs, rs);
return host_math::template multiply_raw<TLC>(as, bs, rs);
#endif
}

Expand All @@ -702,9 +651,9 @@ public:
return multiply_and_add_lsb_neg_modulus_raw_device(as, cs, rs);
#else
Wide r_wide = {};
multiply_raw_host(as, get_neg_modulus(), r_wide.limbs_storage);
host_math::template multiply_raw<TLC>(as, get_neg_modulus(), r_wide.limbs_storage);
Field r = Wide::get_lower(r_wide);
add_limbs<false>(cs, r.limbs_storage, rs);
add_limbs<TLC, false>(cs, r.limbs_storage, rs);
#endif
}

Expand All @@ -713,7 +662,7 @@ public:
#ifdef __CUDA_ARCH__
return multiply_msb_raw_device(as, bs, rs);
#else
return multiply_raw_host(as, bs, rs);
return host_math::template multiply_raw<TLC>(as, bs, rs);
#endif
}

Expand Down Expand Up @@ -759,7 +708,7 @@ public:
if (REDUCTION_SIZE == 0) return xs;
const ff_storage modulus = get_modulus<REDUCTION_SIZE>();
Field rs = {};
return sub_limbs<true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
return sub_limbs<TLC, true>(xs.limbs_storage, modulus, rs.limbs_storage) ? xs : rs;
}

friend std::ostream& operator<<(std::ostream& os, const Field& xs)
Expand All @@ -778,17 +727,17 @@ public:
friend HOST_DEVICE_INLINE Field operator+(Field xs, const Field& ys)
{
Field rs = {};
add_limbs<false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
add_limbs<TLC, false>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
return sub_modulus<1>(rs);
}

friend HOST_DEVICE_INLINE Field operator-(Field xs, const Field& ys)
{
Field rs = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
uint32_t carry = sub_limbs<TLC, true>(xs.limbs_storage, ys.limbs_storage, rs.limbs_storage);
if (carry == 0) return rs;
const ff_storage modulus = get_modulus<1>();
add_limbs<false>(rs.limbs_storage, modulus, rs.limbs_storage);
add_limbs<TLC, false>(rs.limbs_storage, modulus, rs.limbs_storage);
return rs;
}

Expand Down Expand Up @@ -838,10 +787,10 @@ public:
uint32_t carry;
// As mentioned, either 2 or 1 reduction can be performed depending on the field in question.
if (num_of_reductions() == 2) {
carry = sub_limbs<true>(r.limbs_storage, get_modulus<2>(), r_reduced);
carry = sub_limbs<TLC, true>(r.limbs_storage, get_modulus<2>(), r_reduced);
if (carry == 0) r = Field{r_reduced};
}
carry = sub_limbs<true>(r.limbs_storage, get_modulus<1>(), r_reduced);
carry = sub_limbs<TLC, true>(r.limbs_storage, get_modulus<1>(), r_reduced);
if (carry == 0) r = Field{r_reduced};

return r;
Expand Down Expand Up @@ -933,7 +882,7 @@ public:
{
const ff_storage modulus = get_modulus<MODULUS_MULTIPLE>();
Field rs = {};
sub_limbs<false>(modulus, xs.limbs_storage, rs.limbs_storage);
sub_limbs<TLC, false>(modulus, xs.limbs_storage, rs.limbs_storage);
return rs;
}

Expand Down Expand Up @@ -963,7 +912,7 @@ public:
static constexpr HOST_DEVICE_INLINE bool lt(const Field& xs, const Field& ys)
{
ff_storage dummy = {};
uint32_t carry = sub_limbs<true>(xs.limbs_storage, ys.limbs_storage, dummy);
uint32_t carry = sub_limbs<TLC, true>(xs.limbs_storage, ys.limbs_storage, dummy);
return carry;
}

Expand All @@ -983,12 +932,12 @@ public:
while (!(u == one) && !(v == one)) {
while (is_even(u)) {
u = div2(u);
if (is_odd(b)) add_limbs<false>(b.limbs_storage, modulus, b.limbs_storage);
if (is_odd(b)) add_limbs<TLC, false>(b.limbs_storage, modulus, b.limbs_storage);
b = div2(b);
}
while (is_even(v)) {
v = div2(v);
if (is_odd(c)) add_limbs<false>(c.limbs_storage, modulus, c.limbs_storage);
if (is_odd(c)) add_limbs<TLC, false>(c.limbs_storage, modulus, c.limbs_storage);
c = div2(c);
}
if (lt(v, u)) {
Expand Down
Loading

0 comments on commit 2d4059c

Please sign in to comment.