Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add Iterative version of FFT #97

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 161 additions & 77 deletions internal/domain/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package domain

import (
"math/big"
"math/bits"
"slices"
"sync"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand All @@ -21,7 +24,9 @@ import (
// The elements are returned in order as opposed to being returned in
// bit-reversed order.
func (domain *Domain) FftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
return fftG1(values, domain.Generator)
fftVals := slices.Clone(values)
fftG1(fftVals, domain.Generator)
return fftVals
}

// Computes an IFFT(Inverse Fast Fourier Transform) of the G1 elements.
Expand All @@ -32,7 +37,8 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
var invDomainBI big.Int
domain.CardinalityInv.BigInt(&invDomainBI)

inverseFFT := fftG1(values, domain.GeneratorInv)
inverseFFT := slices.Clone(values)
fftG1(inverseFFT, domain.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -47,61 +53,101 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
// This is the actual implementation of [FftG1] with the same convention.
// That is, the returned slice is in "normal", rather than bit-reversed order.
// We assert that values is a slice of length n==2^i and nthRootOfUnity is a primitive n'th root of unity.
func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
n := len(values)
if n == 1 {
return values
// func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// // split the input slice into a (copy of) the values at even resp. odd indices.
// even, odd := takeEvenOdd(values)

// // perform FFT recursively on those parts.
// fftEven := fftG1(even, generatorSquared)
// fftOdd := fftG1(odd, generatorSquared)

// // combine them to get the result
// // - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// // - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// // where w is a n'th primitive root of unity.
// inputPoint := fr.One()
// evaluations := make([]bls12381.G1Affine, n)
// for k := 0; k < n/2; k++ {
// var tmp bls12381.G1Affine

// var inputPointBI big.Int
// inputPoint.BigInt(&inputPointBI)

// if inputPoint.IsOne() {
// tmp.Set(&fftOdd[k])
// } else {
// tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// }

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// // At any rate, we don't really need to optimize here.
// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }

// return evaluations
// }

func fftG1(a []bls12381.G1Affine, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))
if n != 1<<logN {
panic("input size must be a power of 2")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// split the input slice into a (copy of) the values at even resp. odd indices.
even, odd := takeEvenOdd(values)

// perform FFT recursively on those parts.
fftEven := fftG1(even, generatorSquared)
fftOdd := fftG1(odd, generatorSquared)

// combine them to get the result
// - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// where w is a n'th primitive root of unity.
inputPoint := fr.One()
evaluations := make([]bls12381.G1Affine, n)
for k := 0; k < n/2; k++ {
var tmp bls12381.G1Affine

var inputPointBI big.Int
inputPoint.BigInt(&inputPointBI)

if inputPoint.IsOne() {
tmp.Set(&fftOdd[k])
} else {
tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

var wg sync.WaitGroup
for k := uint(0); k < n; k += m {
wg.Add(1)
go func(k uint) {
defer wg.Done()
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
var t bls12381.G1Affine
var bi big.Int
t.ScalarMultiplication(&a[k+j+halfM], w.BigInt(&bi))
u := a[k+j]
a[k+j].Add(&u, &t)
a[k+j+halfM].Sub(&u, &t)
w.Mul(w, wm)
}
}(k)
}

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// At any rate, we don't really need to optimize here.
inputPoint.Mul(&inputPoint, &nthRootOfUnity)
wg.Wait()
}

return evaluations
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
fftVals := slices.Clone(values)
fftFr(fftVals, d.Generator)
return fftVals
}

func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
var invDomain fr.Element
invDomain.SetInt64(int64(len(values)))
invDomain.Inverse(&invDomain)

inverseFFT := fftFr(values, d.GeneratorInv)
inverseFFT := slices.Clone(values)
fftFr(inverseFFT, d.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -110,34 +156,72 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
return inverseFFT
}

func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
n := len(values)
if n == 1 {
return values
func log2PowerOf2(n uint64) uint {
if n == 0 || (n&(n-1)) != 0 {
panic("Input must be a power of 2 and not zero")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

even, odd := takeEvenOdd(values)

fftEven := fftFr(even, generatorSquared)
fftOdd := fftFr(odd, generatorSquared)
return uint(bits.TrailingZeros64(n))
}

inputPoint := fr.One()
evaluations := make([]fr.Element, n)
for k := 0; k < n/2; k++ {
var tmp fr.Element
tmp.Mul(&inputPoint, &fftOdd[k])
func fftFr(a []fr.Element, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)
if n != 1<<logN {
panic("input size must be a power of 2")
}

inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

for k := uint(0); k < n; k += m {
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
t := new(fr.Element).Mul(&a[k+j+halfM], w)
u := a[k+j]
a[k+j].Add(&u, t)
a[k+j+halfM].Sub(&u, t)
w.Mul(w, wm)
}
}
}
return evaluations
}

// func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// even, odd := takeEvenOdd(values)

// fftEven := fftFr(even, generatorSquared)
// fftOdd := fftFr(odd, generatorSquared)

// inputPoint := fr.One()
// evaluations := make([]fr.Element, n)
// for k := 0; k < n/2; k++ {
// var tmp fr.Element
// tmp.Mul(&inputPoint, &fftOdd[k])

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }
// return evaluations
// }

// takeEvenOdd Takes a slice and return two slices
// The first slice contains (a copy of) all of the elements
// at even indices, the second slice contains
Expand All @@ -146,17 +230,17 @@ func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// We assume that the length of the given values slice is even
// so the returned arrays will be the same length.
// This is the case for a radix-2 FFT
func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
n := len(values)
even := make([]T, 0, n/2)
odd := make([]T, 0, n/2)
for i := 0; i < n; i++ {
if i%2 == 0 {
even = append(even, values[i])
} else {
odd = append(odd, values[i])
}
}

return even, odd
}
// func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
// n := len(values)
// even := make([]T, 0, n/2)
// odd := make([]T, 0, n/2)
// for i := 0; i < n; i++ {
// if i%2 == 0 {
// even = append(even, values[i])
// } else {
// odd = append(odd, values[i])
// }
// }

// return even, odd
// }
Loading