Skip to content

Commit

Permalink
add iterative version of fft
Browse files Browse the repository at this point in the history
  • Loading branch information
kevaundray committed Aug 20, 2024
1 parent c755bb3 commit 24dd658
Showing 1 changed file with 65 additions and 22 deletions.
87 changes: 65 additions & 22 deletions internal/domain/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package domain

import (
"math/big"
"math/bits"
"slices"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand Down Expand Up @@ -93,15 +95,18 @@ func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1A
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
fftVals := slices.Clone(values)
fftFr(fftVals, d.Generator)
return fftVals
}

func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
var invDomain fr.Element
invDomain.SetInt64(int64(len(values)))
invDomain.Inverse(&invDomain)

inverseFFT := fftFr(values, d.GeneratorInv)
inverseFFT := slices.Clone(values)
fftFr(inverseFFT, d.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -110,34 +115,72 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
return inverseFFT
}

func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
n := len(values)
if n == 1 {
return values
func log2PowerOf2(n uint64) uint {
if n == 0 || (n&(n-1)) != 0 {
panic("Input must be a power of 2 and not zero")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

even, odd := takeEvenOdd(values)

fftEven := fftFr(even, generatorSquared)
fftOdd := fftFr(odd, generatorSquared)
return uint(bits.TrailingZeros64(n))
}

inputPoint := fr.One()
evaluations := make([]fr.Element, n)
for k := 0; k < n/2; k++ {
var tmp fr.Element
tmp.Mul(&inputPoint, &fftOdd[k])
func fftFr(a []fr.Element, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)
if n != 1<<logN {
panic("input size must be a power of 2")
}

inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

for k := uint(0); k < n; k += m {
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
t := new(fr.Element).Mul(&a[k+j+halfM], w)
u := a[k+j]
a[k+j].Add(&u, t)
a[k+j+halfM].Sub(&u, t)
w.Mul(w, wm)
}
}
}
return evaluations
}

// func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// even, odd := takeEvenOdd(values)

// fftEven := fftFr(even, generatorSquared)
// fftOdd := fftFr(odd, generatorSquared)

// inputPoint := fr.One()
// evaluations := make([]fr.Element, n)
// for k := 0; k < n/2; k++ {
// var tmp fr.Element
// tmp.Mul(&inputPoint, &fftOdd[k])

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }
// return evaluations
// }

// takeEvenOdd Takes a slice and return two slices
// The first slice contains (a copy of) all of the elements
// at even indices, the second slice contains
Expand Down

0 comments on commit 24dd658

Please sign in to comment.