Skip to content

Commit

Permalink
Simplify fft_g1_fast with iterative implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jtraglia committed May 3, 2024
1 parent e08f22e commit 783ad62
Showing 1 changed file with 111 additions and 99 deletions.
210 changes: 111 additions & 99 deletions src/c_kzg_4844.c
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ C_KZG_RET verify_blob_kzg_proof_batch(
}

///////////////////////////////////////////////////////////////////////////////
// FFT for G1 points
// Bitwise Functions
///////////////////////////////////////////////////////////////////////////////

/**
Expand All @@ -1566,102 +1566,6 @@ static bool is_power_of_two(uint64_t n) {
return (n & (n - 1)) == 0;
}

/**
 * Recursive radix-2 Fast Fourier Transform over G1 points.
 *
 * The input is split into its even- and odd-indexed elements (by doubling the
 * stride), each half is transformed recursively into the lower and upper half
 * of @p out, and the two halves are then merged with one butterfly per pair.
 *
 * @param[out]  out             The results (array of length @p n)
 * @param[in]   in              The input data (array of length @p n * @p stride)
 * @param[in]   stride          The input data stride
 * @param[in]   roots           Roots of unity
 *                              (array of length @p n * @p roots_stride)
 * @param[in]   roots_stride    The stride interval among the roots of unity
 * @param[in]   n               Length of the FFT, must be a power of two
 */
static void fft_g1_fast(
    g1_t *out,
    const g1_t *in,
    uint64_t stride,
    const fr_t *roots,
    uint64_t roots_stride,
    uint64_t n
) {
    /* Base case (n / 2 == 0): nothing to combine, just copy the element */
    if (n < 2) {
        *out = *in;
        return;
    }

    const uint64_t half = n / 2;
    g1_t *lo = out;
    g1_t *hi = out + half;

    /* Even-indexed inputs go to the lower half, odd-indexed to the upper */
    fft_g1_fast(lo, in, stride * 2, roots, roots_stride * 2, half);
    fft_g1_fast(hi, in + stride, stride * 2, roots, roots_stride * 2, half);

    /* Merge the two half-size transforms */
    for (uint64_t k = 0; k < half; k++) {
        const fr_t *w = &roots[k * roots_stride];
        g1_t scaled;

        if (fr_is_one(w)) {
            /* Skip the scalar multiplication when the root is one */
            scaled = hi[k];
        } else {
            g1_mul(&scaled, &hi[k], w);
        }

        g1_sub(&hi[k], &lo[k], &scaled);
        blst_p1_add_or_double(&lo[k], &lo[k], &scaled);
    }
}

/**
 * The entry point for forward FFT over G1 points.
 *
 * @param[out]  out The results (array of length n)
 * @param[in]   in  The input data (array of length n)
 * @param[in]   n   Length of the arrays
 * @param[in]   s   The trusted setup
 *
 * @return C_KZG_OK on success. A failed CHECK returns early with the error
 *         chosen by the CHECK macro (defined elsewhere in this file).
 *
 * @remark The array lengths must be a non-zero power of two, no larger than
 *         s->max_width.
 * @remark Use ifft_g1 for inverse transformation.
 */
C_KZG_RET fft_g1(g1_t *out, const g1_t *in, size_t n, const KZGSettings *s) {
    /*
     * is_power_of_two(0) evaluates to true, so zero must be rejected
     * explicitly; otherwise the stride computation below divides by zero.
     */
    CHECK(n != 0);
    CHECK(n <= s->max_width);
    CHECK(is_power_of_two(n));

    /* Step through the expanded roots so that they match a length-n FFT */
    uint64_t stride = s->max_width / n;
    fft_g1_fast(out, in, 1, s->expanded_roots_of_unity, stride, n);

    return C_KZG_OK;
}

/**
 * The entry point for inverse FFT over G1 points.
 *
 * Runs the same transform as fft_g1 but with the reversed roots of unity,
 * then multiplies every output point by the field inverse of n.
 *
 * @param[out]  out The results (array of length n)
 * @param[in]   in  The input data (array of length n)
 * @param[in]   n   Length of the arrays
 * @param[in]   s   The trusted setup
 *
 * @return C_KZG_OK on success. A failed CHECK returns early with the error
 *         chosen by the CHECK macro (defined elsewhere in this file).
 *
 * @remark The array lengths must be a non-zero power of two, no larger than
 *         s->max_width.
 * @remark Use fft_g1 for forward transformation.
 */
C_KZG_RET ifft_g1(g1_t *out, const g1_t *in, size_t n, const KZGSettings *s) {
    /*
     * is_power_of_two(0) evaluates to true, so zero must be rejected
     * explicitly; otherwise the stride computation below divides by zero
     * (and zero has no inverse to scale by).
     */
    CHECK(n != 0);
    CHECK(n <= s->max_width);
    CHECK(is_power_of_two(n));

    uint64_t stride = s->max_width / n;
    fft_g1_fast(out, in, 1, s->reverse_roots_of_unity, stride, n);

    /* Scale every output by 1/n to complete the inverse transform */
    fr_t inv_len;
    fr_from_uint64(&inv_len, n);
    blst_fr_eucl_inverse(&inv_len, &inv_len);
    for (uint64_t i = 0; i < n; i++) {
        g1_mul(&out[i], &out[i], &inv_len);
    }

    return C_KZG_OK;
}

///////////////////////////////////////////////////////////////////////////////
// Trusted Setup Functions
///////////////////////////////////////////////////////////////////////////////

/**
* Reverse the bit order in a 32-bit integer.
*
Expand Down Expand Up @@ -1760,6 +1664,113 @@ static C_KZG_RET bit_reversal_permutation(
return C_KZG_OK;
}


///////////////////////////////////////////////////////////////////////////////
// FFT for G1 points
///////////////////////////////////////////////////////////////////////////////

/**
 * Fast Fourier Transform. An iterative implementation.
 *
 * The input is copied to the output, reordered with a bit-reversal
 * permutation, and then combined in place with radix-2 butterflies over
 * blocks of size 2, 4, ..., n.
 *
 * @param[out]  out     The results (array of length @p n)
 * @param[in]   in      The input data (array of length @p n)
 * @param[in]   roots   The roots of unity; indexed as roots[j * stride] with
 *                      stride = s->max_width / m for each block size m, so it
 *                      must be laid out for a domain of size s->max_width
 * @param[in]   n       The array lengths, must be a non-zero power of two
 *                      no larger than s->max_width
 * @param[in]   s       The trusted setup
 *
 * @return C_KZG_OK on success, C_KZG_BADARGS if @p n is invalid, or the error
 *         returned by bit_reversal_permutation if the reordering fails.
 *
 * @remark @p out and @p in must not overlap; the input is copied with memcpy.
 */
static C_KZG_RET fft_g1_fast(
    g1_t *out,
    const g1_t *in,
    const fr_t *roots,
    size_t n,
    const KZGSettings *s
) {
    C_KZG_RET ret;

    /*
     * Ensure the length is valid. Zero is checked explicitly because
     * is_power_of_two(0) evaluates to true.
     */
    if (n == 0 || n > s->max_width || !is_power_of_two(n)) {
        return C_KZG_BADARGS;
    }

    /* Copy values from input to output */
    memcpy(out, in, sizeof(g1_t) * n);

    /* Bit-reverse permute the order of the values */
    ret = bit_reversal_permutation(out, sizeof(g1_t), n);
    if (ret != C_KZG_OK) {
        /* Report the real failure instead of masking it as bad arguments */
        return ret;
    }

    /* The number of stages is loop-invariant, so compute it once */
    const size_t num_stages = log2_pow2(n);

    /* An iterative FFT implementation */
    for (size_t stage = 1; stage <= num_stages; stage++) {
        /* Shift a size_t one: shifting a plain int is undefined at >= 31 */
        size_t m = (size_t)1 << stage;
        size_t half_m = m >> 1;
        size_t stride = s->max_width / m;

        for (size_t i = 0; i < n; i += m) {
            for (size_t j = 0; j < half_m; j++) {
                g1_t *even = &out[i + j];
                g1_t *odd = &out[i + j + half_m];
                const fr_t *twiddle = &roots[j * stride];

                g1_t odd_times_w;
                if (fr_is_one(twiddle)) {
                    /* Skip the scalar multiplication when the root is one */
                    odd_times_w = *odd;
                } else {
                    g1_mul(&odd_times_w, odd, twiddle);
                }

                /* Butterfly: odd = even - w*odd, even = even + w*odd */
                g1_sub(odd, even, &odd_times_w);
                blst_p1_add_or_double(even, even, &odd_times_w);
            }
        }
    }

    return C_KZG_OK;
}

/**
 * Forward FFT over G1 points.
 *
 * Delegates to fft_g1_fast using the expanded roots of unity stored in the
 * trusted setup; the result of that call is returned unchanged.
 *
 * @param[out]  out The results (array of length n)
 * @param[in]   in  The input data (array of length n)
 * @param[in]   n   Length of the arrays
 * @param[in]   s   The trusted setup
 *
 * @remark The array lengths must be a power of two.
 * @remark Use ifft_g1 for inverse transformation.
 */
C_KZG_RET fft_g1(g1_t *out, const g1_t *in, size_t n, const KZGSettings *s) {
    const fr_t *forward_roots = s->expanded_roots_of_unity;

    return fft_g1_fast(out, in, forward_roots, n, s);
}

/**
 * Inverse FFT over G1 points.
 *
 * Runs fft_g1_fast with the reversed roots of unity from the trusted setup
 * and, if that succeeds, multiplies every output point by the inverse of n.
 *
 * @param[out]  out The results (array of length n)
 * @param[in]   in  The input data (array of length n)
 * @param[in]   n   Length of the arrays
 * @param[in]   s   The trusted setup
 *
 * @remark The array lengths must be a power of two.
 * @remark Use fft_g1 for forward transformation.
 */
C_KZG_RET ifft_g1(g1_t *out, const g1_t *in, size_t n, const KZGSettings *s) {
    fr_t scale;

    /* Transform first; any failure is handed straight back to the caller */
    C_KZG_RET status = fft_g1_fast(out, in, s->reverse_roots_of_unity, n, s);
    if (status != C_KZG_OK) {
        return status;
    }

    /* scale = n^-1 */
    fr_from_uint64(&scale, n);
    blst_fr_eucl_inverse(&scale, &scale);

    /* Apply the scaling to each output point in place */
    for (g1_t *point = out; point != out + n; point++) {
        g1_mul(point, point, &scale);
    }

    return C_KZG_OK;
}

///////////////////////////////////////////////////////////////////////////////
// Trusted Setup Functions
///////////////////////////////////////////////////////////////////////////////

/**
* Generate powers of a root of unity in the field.
*
Expand Down Expand Up @@ -3236,6 +3247,7 @@ C_KZG_RET cells_to_blob(Blob *blob, const Cell *cells) {
* @remark Use recover_all_cells to recover missing cells.
* @remark If cells is NULL, they won't be computed.
* @remark If proofs is NULL, they won't be computed.
* @remark Returns an error if both cells and proofs are NULL.
*/
C_KZG_RET compute_cells_and_kzg_proofs(
Cell *cells, KZGProof *proofs, const Blob *blob, const KZGSettings *s
Expand Down Expand Up @@ -3566,11 +3578,11 @@ static C_KZG_RET compute_r_powers_for_verify_cell_kzg_proof_batch(
}

for (size_t i = 0; i < num_cells; i++) {
/* Copy row id */
/* Copy row index */
bytes_from_uint64(offset, row_indices[i]);
offset += sizeof(uint64_t);

/* Copy column id */
/* Copy column index */
bytes_from_uint64(offset, column_indices[i]);
offset += sizeof(uint64_t);

Expand Down

0 comments on commit 783ad62

Please sign in to comment.