diff --git a/ecc/bls12-377/fr/fft/domain.go b/ecc/bls12-377/fr/fft/domain.go index b14cd48f0..355a1f8aa 100644 --- a/ecc/bls12-377/fr/fft/domain.go +++ b/ecc/bls12-377/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bls12-377/fr/fft/fft.go b/ecc/bls12-377/fr/fft/fft.go index 91172620a..8c01bf23a 100644 --- a/ecc/bls12-377/fr/fft/fft.go +++ b/ecc/bls12-377/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bls12-378/fr/fft/domain.go b/ecc/bls12-378/fr/fft/domain.go index e1bb5262b..2a8a00203 100644 --- a/ecc/bls12-378/fr/fft/domain.go +++ b/ecc/bls12-378/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bls12-378/fr/fft/fft.go b/ecc/bls12-378/fr/fft/fft.go index 240828e0d..a74c8b4e8 100644 --- a/ecc/bls12-378/fr/fft/fft.go +++ b/ecc/bls12-378/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bls12-381/fr/fft/domain.go b/ecc/bls12-381/fr/fft/domain.go index c4b1f8a61..37f5a1521 100644 --- a/ecc/bls12-381/fr/fft/domain.go +++ b/ecc/bls12-381/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bls12-381/fr/fft/fft.go b/ecc/bls12-381/fr/fft/fft.go index 0bc140564..443a46bde 100644 --- a/ecc/bls12-381/fr/fft/fft.go +++ b/ecc/bls12-381/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bls24-315/fr/fft/domain.go b/ecc/bls24-315/fr/fft/domain.go index fdf95818f..7316234ea 100644 --- a/ecc/bls24-315/fr/fft/domain.go +++ b/ecc/bls24-315/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bls24-315/fr/fft/fft.go b/ecc/bls24-315/fr/fft/fft.go index b897b1d4d..30fb173cf 100644 --- a/ecc/bls24-315/fr/fft/fft.go +++ b/ecc/bls24-315/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bls24-317/fr/fft/domain.go b/ecc/bls24-317/fr/fft/domain.go index fb3eab9cc..de7e62550 100644 --- a/ecc/bls24-317/fr/fft/domain.go +++ b/ecc/bls24-317/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bls24-317/fr/fft/fft.go b/ecc/bls24-317/fr/fft/fft.go index 344c171ee..cf4230ed1 100644 --- a/ecc/bls24-317/fr/fft/fft.go +++ b/ecc/bls24-317/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bn254/fr/fft/domain.go b/ecc/bn254/fr/fft/domain.go index 2926d60ad..ea657d048 100644 --- a/ecc/bn254/fr/fft/domain.go +++ b/ecc/bn254/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bn254/fr/fft/fft.go b/ecc/bn254/fr/fft/fft.go index e38674974..151b7832f 100644 --- a/ecc/bn254/fr/fft/fft.go +++ b/ecc/bn254/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bw6-633/fr/fft/domain.go b/ecc/bw6-633/fr/fft/domain.go index bb5326b82..040015ccc 100644 --- a/ecc/bw6-633/fr/fft/domain.go +++ b/ecc/bw6-633/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bw6-633/fr/fft/fft.go b/ecc/bw6-633/fr/fft/fft.go index 522e0ebb1..7485014f1 100644 --- a/ecc/bw6-633/fr/fft/fft.go +++ b/ecc/bw6-633/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bw6-756/fr/fft/domain.go b/ecc/bw6-756/fr/fft/domain.go index 632d3392a..e71e55f77 100644 --- a/ecc/bw6-756/fr/fft/domain.go +++ b/ecc/bw6-756/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bw6-756/fr/fft/fft.go b/ecc/bw6-756/fr/fft/fft.go index de63d18a2..5fd501c20 100644 --- a/ecc/bw6-756/fr/fft/fft.go +++ b/ecc/bw6-756/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/ecc/bw6-761/fr/fft/domain.go b/ecc/bw6-761/fr/fft/domain.go index b36b27cf2..8042e60ff 100644 --- a/ecc/bw6-761/fr/fft/domain.go +++ b/ecc/bw6-761/fr/fft/domain.go @@ -53,12 +53,10 @@ type Domain struct { // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTable []fr.Element // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain + CosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality @@ -90,9 +88,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -118,15 +113,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -252,9 +238,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -278,8 +261,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/ecc/bw6-761/fr/fft/fft.go b/ecc/bw6-761/fr/fft/fft.go index d42c7c939..5c05decd1 100644 --- a/ecc/bw6-761/fr/fft/fft.go +++ b/ecc/bw6-761/fr/fft/fft.go @@ -44,18 +44,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -109,21 +113,26 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) } diff --git a/internal/generator/fft/template/domain.go.tmpl b/internal/generator/fft/template/domain.go.tmpl index e2aafc9e6..14876d8aa 100644 --- a/internal/generator/fft/template/domain.go.tmpl +++ b/internal/generator/fft/template/domain.go.tmpl @@ -35,11 +35,9 @@ type Domain struct { // CosetTable u*<1,g,..,g^(n-1)> CosetTable []fr.Element - CosetTableReversed []fr.Element // optional, this is computed on demand at the creation of the domain // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j CosetTableInv []fr.Element - CosetTableInvReversed []fr.Element // optional, this is computed on demand at the creation of the domain } @@ -89,9 +87,6 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { // twiddle factors domain.preComputeTwiddles() - // store the bit reversed coset tables - domain.reverseCosetTables() - return domain } @@ -142,15 +137,6 @@ func Generator(m uint64) (fr.Element, error) { return generator, nil } -func (d *Domain) reverseCosetTables() { - d.CosetTableReversed = make([]fr.Element, d.Cardinality) - d.CosetTableInvReversed = make([]fr.Element, d.Cardinality) - copy(d.CosetTableReversed, d.CosetTable) - copy(d.CosetTableInvReversed, d.CosetTableInv) - BitReverse(d.CosetTableReversed) - BitReverse(d.CosetTableInvReversed) -} - func (d *Domain) preComputeTwiddles() { // nb fft stages @@ -277,9 +263,6 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() - return dec.BytesRead(), nil } @@ -303,8 +286,6 @@ func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { // twiddle factors d.preComputeTwiddles() - // store the bit reversed coset tables if needed - d.reverseCosetTables() close(chDone) }() diff --git a/internal/generator/fft/template/fft.go.tmpl b/internal/generator/fft/template/fft.go.tmpl index bea58e0c5..90729bd6e 100644 --- a/internal/generator/fft/template/fft.go.tmpl +++ b/internal/generator/fft/template/fft.go.tmpl @@ -26,18 +26,22 @@ func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) // if coset != 0, scale by coset table if opt.coset { - scale := func(cosetTable []fr.Element) { + if decimation == DIT { + // scale by coset table (in bit reversed order) parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]) + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTable[irev]) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableReversed) - } else { - scale(domain.CosetTable) + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.CosetTable[i]) + } + }, opt.nbTasks) } } @@ -91,21 +95,27 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... return } - scale := func(cosetTable []fr.Element) { + + if decimation == DIT { parallel.Execute(len(a), func(start, end int) { for i := start; i < end; i++ { - a[i].Mul(&a[i], &cosetTable[i]). + a[i].Mul(&a[i], &domain.CosetTableInv[i]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) - } - if decimation == DIT { - scale(domain.CosetTableInv) return } - // decimation == DIF - scale(domain.CosetTableInvReversed) + // decimation == DIF, need to access coset table in bit reversed order. + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) }