Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(DO NOT MERGE): Merge #97 and #98 #99

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
238 changes: 161 additions & 77 deletions internal/domain/fft.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package domain

import (
"math/big"
"math/bits"
"slices"
"sync"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Expand All @@ -21,7 +24,9 @@ import (
// The elements are returned in order as opposed to being returned in
// bit-reversed order.
func (domain *Domain) FftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
return fftG1(values, domain.Generator)
fftVals := slices.Clone(values)
fftG1(fftVals, domain.Generator)
return fftVals
}

// Computes an IFFT(Inverse Fast Fourier Transform) of the G1 elements.
Expand All @@ -32,7 +37,8 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
var invDomainBI big.Int
domain.CardinalityInv.BigInt(&invDomainBI)

inverseFFT := fftG1(values, domain.GeneratorInv)
inverseFFT := slices.Clone(values)
fftG1(inverseFFT, domain.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -47,61 +53,101 @@ func (domain *Domain) IfftG1(values []bls12381.G1Affine) []bls12381.G1Affine {
// This is the actual implementation of [FftG1] with the same convention.
// That is, the returned slice is in "normal", rather than bit-reversed order.
// We assert that values is a slice of length n==2^i and nthRootOfUnity is a primitive n'th root of unity.
func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
n := len(values)
if n == 1 {
return values
// func fftG1(values []bls12381.G1Affine, nthRootOfUnity fr.Element) []bls12381.G1Affine {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// // split the input slice into a (copy of) the values at even resp. odd indices.
// even, odd := takeEvenOdd(values)

// // perform FFT recursively on those parts.
// fftEven := fftG1(even, generatorSquared)
// fftOdd := fftG1(odd, generatorSquared)

// // combine them to get the result
// // - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// // - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// // where w is a n'th primitive root of unity.
// inputPoint := fr.One()
// evaluations := make([]bls12381.G1Affine, n)
// for k := 0; k < n/2; k++ {
// var tmp bls12381.G1Affine

// var inputPointBI big.Int
// inputPoint.BigInt(&inputPointBI)

// if inputPoint.IsOne() {
// tmp.Set(&fftOdd[k])
// } else {
// tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// }

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// // we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// // At any rate, we don't really need to optimize here.
// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }

// return evaluations
// }

func fftG1(a []bls12381.G1Affine, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))
if n != 1<<logN {
panic("input size must be a power of 2")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// split the input slice into a (copy of) the values at even resp. odd indices.
even, odd := takeEvenOdd(values)

// perform FFT recursively on those parts.
fftEven := fftG1(even, generatorSquared)
fftOdd := fftG1(odd, generatorSquared)

// combine them to get the result
// - evaluations[k] = fftEven[k] + w^k * fftOdd[k]
// - evaluations[k] = fftEven[k] - w^k * fftOdd[k]
// where w is a n'th primitive root of unity.
inputPoint := fr.One()
evaluations := make([]bls12381.G1Affine, n)
for k := 0; k < n/2; k++ {
var tmp bls12381.G1Affine

var inputPointBI big.Int
inputPoint.BigInt(&inputPointBI)

if inputPoint.IsOne() {
tmp.Set(&fftOdd[k])
} else {
tmp.ScalarMultiplication(&fftOdd[k], &inputPointBI)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

var wg sync.WaitGroup
for k := uint(0); k < n; k += m {
wg.Add(1)
go func(k uint) {
defer wg.Done()
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
var t bls12381.G1Affine
var bi big.Int
t.ScalarMultiplication(&a[k+j+halfM], w.BigInt(&bi))
u := a[k+j]
a[k+j].Add(&u, &t)
a[k+j+halfM].Sub(&u, &t)
w.Mul(w, wm)
}
}(k)
}

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// we could take this from precomputed values in Domain (as domain.roots[n*k]), but then we would need to pass the domain.
// At any rate, we don't really need to optimize here.
inputPoint.Mul(&inputPoint, &nthRootOfUnity)
wg.Wait()
}

return evaluations
}

func (d *Domain) FftFr(values []fr.Element) []fr.Element {
return fftFr(values, d.Generator)
fftVals := slices.Clone(values)
fftFr(fftVals, d.Generator)
return fftVals
}

func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
var invDomain fr.Element
invDomain.SetInt64(int64(len(values)))
invDomain.Inverse(&invDomain)

inverseFFT := fftFr(values, d.GeneratorInv)
inverseFFT := slices.Clone(values)
fftFr(inverseFFT, d.GeneratorInv)

// scale by the inverse of the domain size
for i := 0; i < len(inverseFFT); i++ {
Expand All @@ -110,34 +156,72 @@ func (d *Domain) IfftFr(values []fr.Element) []fr.Element {
return inverseFFT
}

func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
n := len(values)
if n == 1 {
return values
func log2PowerOf2(n uint64) uint {
if n == 0 || (n&(n-1)) != 0 {
panic("Input must be a power of 2 and not zero")
}

var generatorSquared fr.Element
generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

even, odd := takeEvenOdd(values)

fftEven := fftFr(even, generatorSquared)
fftOdd := fftFr(odd, generatorSquared)
return uint(bits.TrailingZeros64(n))
}

inputPoint := fr.One()
evaluations := make([]fr.Element, n)
for k := 0; k < n/2; k++ {
var tmp fr.Element
tmp.Mul(&inputPoint, &fftOdd[k])
func fftFr(a []fr.Element, omega fr.Element) {
n := uint(len(a))
logN := log2PowerOf2(uint64(n))

evaluations[k].Add(&fftEven[k], &tmp)
evaluations[k+n/2].Sub(&fftEven[k], &tmp)
if n != 1<<logN {
panic("input size must be a power of 2")
}

inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// Bit-reversal permutation
BitReverse(a)

// Main FFT computation
for s := uint(1); s <= logN; s++ {
m := uint(1) << s
halfM := m >> 1
wm := new(fr.Element).Exp(omega, new(big.Int).SetUint64(uint64(n/m)))

for k := uint(0); k < n; k += m {
w := new(fr.Element).SetOne()
for j := uint(0); j < halfM; j++ {
t := new(fr.Element).Mul(&a[k+j+halfM], w)
u := a[k+j]
a[k+j].Add(&u, t)
a[k+j+halfM].Sub(&u, t)
w.Mul(w, wm)
}
}
}
return evaluations
}

// func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// n := len(values)
// if n == 1 {
// return values
// }

// var generatorSquared fr.Element
// generatorSquared.Square(&nthRootOfUnity) // generator with order n/2

// even, odd := takeEvenOdd(values)

// fftEven := fftFr(even, generatorSquared)
// fftOdd := fftFr(odd, generatorSquared)

// inputPoint := fr.One()
// evaluations := make([]fr.Element, n)
// for k := 0; k < n/2; k++ {
// var tmp fr.Element
// tmp.Mul(&inputPoint, &fftOdd[k])

// evaluations[k].Add(&fftEven[k], &tmp)
// evaluations[k+n/2].Sub(&fftEven[k], &tmp)

// inputPoint.Mul(&inputPoint, &nthRootOfUnity)
// }
// return evaluations
// }

// takeEvenOdd Takes a slice and return two slices
// The first slice contains (a copy of) all of the elements
// at even indices, the second slice contains
Expand All @@ -146,17 +230,17 @@ func fftFr(values []fr.Element, nthRootOfUnity fr.Element) []fr.Element {
// We assume that the length of the given values slice is even
// so the returned arrays will be the same length.
// This is the case for a radix-2 FFT
func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
n := len(values)
even := make([]T, 0, n/2)
odd := make([]T, 0, n/2)
for i := 0; i < n; i++ {
if i%2 == 0 {
even = append(even, values[i])
} else {
odd = append(odd, values[i])
}
}

return even, odd
}
// func takeEvenOdd[T interface{}](values []T) ([]T, []T) {
// n := len(values)
// even := make([]T, 0, n/2)
// odd := make([]T, 0, n/2)
// for i := 0; i < n; i++ {
// if i%2 == 0 {
// even = append(even, values[i])
// } else {
// odd = append(odd, values[i])
// }
// }

// return even, odd
// }
21 changes: 20 additions & 1 deletion internal/kzg/kzg_prove.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,34 @@
import (
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-eth-kzg/internal/domain"
"github.com/crate-crypto/go-eth-kzg/internal/poly"
)

func Open(domain *domain.Domain, polyCoeff []fr.Element, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {

Check failure on line 9 in internal/kzg/kzg_prove.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary leading newline (whitespace)

outputPoint := poly.PolyEval(polyCoeff, evaluationPoint)

quotient := poly.DividePolyByXminusA(polyCoeff, evaluationPoint)

comm, err := ck.Commit(quotient, 0)
if err != nil {
return OpeningProof{}, nil

Check failure on line 17 in internal/kzg/kzg_prove.go

View workflow job for this annotation

GitHub Actions / Lint

error is not nil (line 15) but it returns nil (nilerr)
}

return OpeningProof{
QuotientCommitment: *comm,
InputPoint: evaluationPoint,
ClaimedValue: outputPoint,
}, nil
}

// Open verifies that a polynomial f(x) when evaluated at a point `z` is equal to `f(z)`
//
// numGoRoutines is used to configure the amount of concurrency needed. Setting this
// value to a negative number or 0 will make it default to the number of CPUs.
//
// [compute_kzg_proof_impl]: https://github.com/ethereum/consensus-specs/blob/017a8495f7671f5fff2075a9bfc9238c1a0982f8/specs/deneb/polynomial-commitments.md#compute_kzg_proof_impl
func Open(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {
func Open_(domain *domain.Domain, p Polynomial, evaluationPoint fr.Element, ck *CommitKey, numGoRoutines int) (OpeningProof, error) {
if len(p) == 0 || len(p) > len(ck.G1) {
return OpeningProof{}, ErrInvalidPolynomialSize
}
Expand Down
6 changes: 3 additions & 3 deletions internal/kzg/kzg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

func TestProofVerifySmoke(t *testing.T) {
domain := domain.NewDomain(4)
srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234))
srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234))

// polynomial in lagrange form
// polynomial in monomial form
poly := Polynomial{fr.NewElement(2), fr.NewElement(3), fr.NewElement(4), fr.NewElement(5)}

comm, _ := srs.CommitKey.Commit(poly, 0)
Expand All @@ -29,7 +29,7 @@ func TestProofVerifySmoke(t *testing.T) {

func TestBatchVerifySmoke(t *testing.T) {
domain := domain.NewDomain(4)
srs, _ := newLagrangeSRSInsecure(*domain, big.NewInt(1234))
srs, _ := newMonomialSRSInsecure(*domain, big.NewInt(1234))

numProofs := 10
commitments := make([]Commitment, 0, numProofs)
Expand Down
12 changes: 5 additions & 7 deletions internal/poly/poly.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

result := make([]fr.Element, maxPolyLen)

for i := 0; i < int(minPolyLen); i++ {

Check failure on line 23 in internal/poly/poly.go

View workflow job for this annotation

GitHub Actions / Lint

G115: integer overflow conversion uint64 -> int (gosec)
result[i].Add(&a[i], &b[i])
}

Expand All @@ -28,7 +28,7 @@
// into result
// If b has more coefficients than a, copy the remaining coefficients of b
// and copy them into result
if int(numCoeffs(a)) > int(minPolyLen) {

Check failure on line 31 in internal/poly/poly.go

View workflow job for this annotation

GitHub Actions / Lint

G115: integer overflow conversion uint64 -> int (gosec)
for i := minPolyLen; i < numCoeffs(a); i++ {
result[i].Set(&a[i])
}
Expand Down Expand Up @@ -85,15 +85,13 @@
// PolyEval evaluates a polynomial f(x) at a point z, computing f(z).
// The polynomial is given in coefficient form, and `z` is denoted as inputPoint.
func PolyEval(poly PolynomialCoeff, inputPoint fr.Element) fr.Element {
result := fr.NewElement(0)

for i := len(poly) - 1; i >= 0; i-- {
tmp := fr.Element{}
tmp.Mul(&result, &inputPoint)
result.Add(&tmp, &poly[i])
res := poly[len(poly)-1]
for i := len(poly) - 2; i >= 0; i-- {
res.Mul(&res, &inputPoint)
res.Add(&res, &poly[i])
}

return result
return res
}

// DividePolyByXminusA computes f(x) / (x - a) and returns the quotient.
Expand Down
Loading
Loading