Skip to content

Commit

Permalink
copy method for fftG1
Browse files Browse the repository at this point in the history
  • Loading branch information
kevaundray committed Aug 20, 2024
1 parent 24dd658 commit ea272f9
Showing 1 changed file with 84 additions and 48 deletions.
132 changes: 84 additions & 48 deletions internal/domain/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import (
// The elements are returned in order as opposed to being returned in
// bit-reversed order.
func (domain *Domain) FftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
return fftG1(values, domain.Generator)
fftVals := slices.Clone(values)
fftG1(fftVals, domain.Generator)
return fftVals
}

// Computes an IFFT(Inverse Fast Fourier Transform) of the G1 elements.
Expand All @@ -34,7 +36,8 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
var invDomainBI big.Int
domain.CardinalityInv.BigInt(&invDomainBI)

inverseFFT := fftG1(values, domain.GeneratorInv)
inverseFFT := slices.Clone(values)
fftG1(inverseFFT, domain.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -49,49 +52,82 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
// This is the actual implementation of [FftG1] with the same convention.
// That is, the returned slice is in "normal", rather than bit-reversed order.
// We assert that values is a slice of length n==2^i and nthRootOfUnity is a primitive n'th root of unity.
func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
n := len(values)
if n == 1 {
return values
}
// func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// // split the input slice into a (copy of) the values at even resp. odd indices.
// even, odd := takeEvenOdd(values)

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2
// // perform FFT recursively on those parts.
// fftEven := fftG1(even, generatorSquared)
// fftOdd := fftG1(odd, generatorSquared)

// split the input slice into a (copy of) the values at even resp. odd indices.
even, odd := takeEvenOdd(values)
// // combine them to get the result
// // - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// // - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// // where w is a n'th primitive root of unity.
// inputPoint := fr.One()
// evaluations := make([]bls12381.G1Affine, n)
// for k := 0; k < n/2; k++ {
// var tmp bls12381.G1Affine

// perform FFT recursively on those parts.
fftEven := fftG1(even, generatorSquared)
fftOdd := fftG1(odd, generatorSquared)
// var inputPointBI big.Int
// inputPoint.BigInt(&inputPointBI)

// combine them to get the result
// - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// where w is a n'th primitive root of unity.
inputPoint := fr.One()
evaluations := make([]bls12381.G1Affine, n)
for k := 0; k < n/2; k++ {
var tmp bls12381.G1Affine
// if inputPoint.IsOne() {
// tmp.Set(&fftOdd[k])
// } else {
// tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// }

var inputPointBI big.Int
inputPoint.BigInt(&inputPointBI)
// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

if inputPoint.IsOne() {
tmp.Set(&fftOdd[k])
} else {
tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
}
// // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// // At any rate, we don't really need to optimize here.
// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)
// return evaluations
// }

func fftG1(a []bls12381.G1Affine, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))

// we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// At any rate, we don't really need to optimize here.
inputPoint.Mul(&inputPoint, &nthRootOfUnity)
if n != 1<<logN {
panic("input size must be a power of 2")
}

return evaluations
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

for k := uint(0); k < n; k += m {
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
var t bls12381.G1Affine
var bi big.Int

t.ScalarMultiplication(&a[k+j+halfM], w.BigInt(&bi))
u := a[k+j]
a[k+j].Add(&u, &t)
a[k+j+halfM].Sub(&u, &t)
w.Mul(w, wm)
}
}
}
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
Expand Down Expand Up @@ -189,17 +225,17 @@ func fftFr(a []fr.Element, omega fr.Element) {
// We assume that the length of the given values slice is even
// so the returned arrays will be the same length.
// This is the case for a radix-2 FFT
func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
n := len(values)
even := make([]T, 0, n/2)
odd := make([]T, 0, n/2)
for i := 0; i < n; i++ {
if i%2 == 0 {
even = append(even, values[i])
} else {
odd = append(odd, values[i])
}
}
// func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
// n := len(values)
// even := make([]T, 0, n/2)
// odd := make([]T, 0, n/2)
// for i := 0; i < n; i++ {
// if i%2 == 0 {
// even = append(even, values[i])
// } else {
// odd = append(odd, values[i])
// }
// }

return even, odd
}
// return even, odd
// }

0 comments on commit ea272f9

Please sign in to comment.