diff --git a/backend/groth16/bls12-377/mpcsetup/lagrange.go b/backend/groth16/bls12-377/mpcsetup/lagrange.go index be7a1d0b27..2d7043c9a6 100644 --- a/backend/groth16/bls12-377/mpcsetup/lagrange.go +++ b/backend/groth16/bls12-377/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bls12-381/mpcsetup/lagrange.go b/backend/groth16/bls12-381/mpcsetup/lagrange.go index 700d3b31b6..efd77055e1 100644 --- a/backend/groth16/bls12-381/mpcsetup/lagrange.go +++ b/backend/groth16/bls12-381/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bls24-315/mpcsetup/lagrange.go b/backend/groth16/bls24-315/mpcsetup/lagrange.go index 9f8cee5768..01cf0bb7f4 100644 --- a/backend/groth16/bls24-315/mpcsetup/lagrange.go +++ b/backend/groth16/bls24-315/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bls24-317/mpcsetup/lagrange.go b/backend/groth16/bls24-317/mpcsetup/lagrange.go index 87502a5c09..ea2c29edd3 100644 --- a/backend/groth16/bls24-317/mpcsetup/lagrange.go +++ b/backend/groth16/bls24-317/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bn254/mpcsetup/lagrange.go b/backend/groth16/bn254/mpcsetup/lagrange.go index ffa21be073..886e489248 100644 --- a/backend/groth16/bn254/mpcsetup/lagrange.go +++ b/backend/groth16/bn254/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bw6-633/mpcsetup/lagrange.go b/backend/groth16/bw6-633/mpcsetup/lagrange.go index 0d0e87309e..4584c3964a 100644 --- a/backend/groth16/bw6-633/mpcsetup/lagrange.go +++ b/backend/groth16/bw6-633/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/groth16/bw6-761/mpcsetup/lagrange.go b/backend/groth16/bw6-761/mpcsetup/lagrange.go index 962efe5ff9..145271ddcd 100644 --- a/backend/groth16/bw6-761/mpcsetup/lagrange.go +++ b/backend/groth16/bw6-761/mpcsetup/lagrange.go @@ -36,7 +36,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -58,7 +59,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/backend/plonk/bls12-377/marshal.go b/backend/plonk/bls12-377/marshal.go index d4c1397b5f..90ad763897 100644 --- a/backend/plonk/bls12-377/marshal.go +++ b/backend/plonk/bls12-377/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" "github.com/consensys/gnark-crypto/ecc/bls12-377/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bls12-377/marshal_test.go b/backend/plonk/bls12-377/marshal_test.go index 7618d455f4..6179b0527f 100644 --- a/backend/plonk/bls12-377/marshal_test.go +++ b/backend/plonk/bls12-377/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bls12-377/prove.go b/backend/plonk/bls12-377/prove.go index 6a9c269860..b720e79af9 100644 --- a/backend/plonk/bls12-377/prove.go +++ b/backend/plonk/bls12-377/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go index e45d620c18..0b9158e9a5 100644 --- a/backend/plonk/bls12-377/setup.go +++ b/backend/plonk/bls12-377/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bls12-381/marshal.go b/backend/plonk/bls12-381/marshal.go index ad50b7ac50..af019887bd 100644 --- a/backend/plonk/bls12-381/marshal.go +++ b/backend/plonk/bls12-381/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" "github.com/consensys/gnark-crypto/ecc/bls12-381/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bls12-381/marshal_test.go b/backend/plonk/bls12-381/marshal_test.go index 3a02d789c1..900ab941f4 100644 --- a/backend/plonk/bls12-381/marshal_test.go +++ b/backend/plonk/bls12-381/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bls12-381/prove.go b/backend/plonk/bls12-381/prove.go index 0b46e8930e..8786e12df9 100644 --- a/backend/plonk/bls12-381/prove.go +++ b/backend/plonk/bls12-381/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bls12-381/setup.go b/backend/plonk/bls12-381/setup.go index 3fdb43cd52..e0d7e50c06 100644 --- a/backend/plonk/bls12-381/setup.go +++ b/backend/plonk/bls12-381/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bls24-315/marshal.go b/backend/plonk/bls24-315/marshal.go index 41d4f69054..2ec858fb7b 100644 --- a/backend/plonk/bls24-315/marshal.go +++ b/backend/plonk/bls24-315/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" "github.com/consensys/gnark-crypto/ecc/bls24-315/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bls24-315/marshal_test.go b/backend/plonk/bls24-315/marshal_test.go index 29cf02bddb..52ce8d71c4 100644 --- a/backend/plonk/bls24-315/marshal_test.go +++ b/backend/plonk/bls24-315/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bls24-315/prove.go b/backend/plonk/bls24-315/prove.go index 95161e9ede..aa4ff21213 100644 --- a/backend/plonk/bls24-315/prove.go +++ b/backend/plonk/bls24-315/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bls24-315/setup.go b/backend/plonk/bls24-315/setup.go index 2a2756991a..8035bbed8a 100644 --- a/backend/plonk/bls24-315/setup.go +++ b/backend/plonk/bls24-315/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bls24-317/marshal.go b/backend/plonk/bls24-317/marshal.go index 93d125eb7b..21e55f7e75 100644 --- a/backend/plonk/bls24-317/marshal.go +++ b/backend/plonk/bls24-317/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" "github.com/consensys/gnark-crypto/ecc/bls24-317/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bls24-317/marshal_test.go b/backend/plonk/bls24-317/marshal_test.go index 54c93b4de9..7e95c5b42e 100644 --- a/backend/plonk/bls24-317/marshal_test.go +++ b/backend/plonk/bls24-317/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-317" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bls24-317/prove.go b/backend/plonk/bls24-317/prove.go index 8a4ef0e9b6..334178413b 100644 --- a/backend/plonk/bls24-317/prove.go +++ b/backend/plonk/bls24-317/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bls24-317/setup.go b/backend/plonk/bls24-317/setup.go index 6ebb0f8ee3..1359bb5097 100644 --- a/backend/plonk/bls24-317/setup.go +++ b/backend/plonk/bls24-317/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bn254/marshal.go b/backend/plonk/bn254/marshal.go index dbf2843552..4f128ae2ab 100644 --- a/backend/plonk/bn254/marshal.go +++ b/backend/plonk/bn254/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" "github.com/consensys/gnark-crypto/ecc/bn254/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bn254/marshal_test.go b/backend/plonk/bn254/marshal_test.go index a2ab848ced..2a193872cf 100644 --- a/backend/plonk/bn254/marshal_test.go +++ b/backend/plonk/bn254/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bn254/prove.go b/backend/plonk/bn254/prove.go index f59d7b3deb..2bb6ea33e3 100644 --- a/backend/plonk/bn254/prove.go +++ b/backend/plonk/bn254/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bn254/setup.go b/backend/plonk/bn254/setup.go index 4a51f25b79..5d916034f6 100644 --- a/backend/plonk/bn254/setup.go +++ b/backend/plonk/bn254/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bw6-633/marshal.go b/backend/plonk/bw6-633/marshal.go index d0cac16c98..73d1828266 100644 --- a/backend/plonk/bw6-633/marshal.go +++ b/backend/plonk/bw6-633/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" "github.com/consensys/gnark-crypto/ecc/bw6-633/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bw6-633/marshal_test.go b/backend/plonk/bw6-633/marshal_test.go index c47269443c..9804f8466e 100644 --- a/backend/plonk/bw6-633/marshal_test.go +++ b/backend/plonk/bw6-633/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bw6-633/prove.go b/backend/plonk/bw6-633/prove.go index 6b72f71fb1..fc25a75518 100644 --- a/backend/plonk/bw6-633/prove.go +++ b/backend/plonk/bw6-633/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bw6-633/setup.go b/backend/plonk/bw6-633/setup.go index 81f2735a32..8aa342a41a 100644 --- a/backend/plonk/bw6-633/setup.go +++ b/backend/plonk/bw6-633/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/bw6-761/marshal.go b/backend/plonk/bw6-761/marshal.go index 33c85beda5..f3d2f3417d 100644 --- a/backend/plonk/bw6-761/marshal.go +++ b/backend/plonk/bw6-761/marshal.go @@ -19,10 +19,6 @@ package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" "github.com/consensys/gnark-crypto/ecc/bw6-761/kzg" "io" ) @@ -116,19 +112,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -149,35 +133,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -197,18 +153,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -224,98 +169,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/backend/plonk/bw6-761/marshal_test.go b/backend/plonk/bw6-761/marshal_test.go index bc460d7f9f..cc6aa3c406 100644 --- a/backend/plonk/bw6-761/marshal_test.go +++ b/backend/plonk/bw6-761/marshal_test.go @@ -20,9 +20,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/iop" "github.com/consensys/gnark/io" "math/big" "math/rand" @@ -60,8 +57,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -70,36 +65,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/backend/plonk/bw6-761/prove.go b/backend/plonk/bw6-761/prove.go index c433e6b122..10f2b7552d 100644 --- a/backend/plonk/bw6-761/prove.go +++ b/backend/plonk/bw6-761/prove.go @@ -138,9 +138,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -211,9 +208,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -224,8 +218,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -253,43 +250,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8*sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4*sizeSystem, fft.WithoutPrecompute()) } + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) - - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) - - close(s.chNumeratorInit) + // build trace + s.trace = NewTrace(spr, s.domain0) - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -321,7 +306,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -336,7 +321,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -395,7 +379,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -494,7 +478,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -518,20 +502,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -543,7 +526,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -571,7 +554,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -613,11 +596,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -652,17 +635,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -675,7 +658,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -691,7 +674,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -720,7 +703,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -750,7 +733,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -775,9 +758,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -800,13 +780,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -835,6 +819,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -843,8 +844,6 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality - nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 gateConstraint := func(u ...fr.Element) fr.Element { @@ -867,7 +866,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -899,11 +898,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -914,14 +913,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -933,17 +931,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -953,28 +951,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -991,6 +976,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -1000,10 +996,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -1013,7 +1014,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1046,9 +1047,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1057,8 +1059,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1275,7 +1280,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1286,8 +1291,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1297,12 +1302,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1321,7 +1326,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1331,17 +1336,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/backend/plonk/bw6-761/setup.go b/backend/plonk/bw6-761/setup.go index cd54cf19e7..9764e5a796 100644 --- a/backend/plonk/bw6-761/setup.go +++ b/backend/plonk/bw6-761/setup.go @@ -78,34 +78,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -114,26 +94,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -141,22 +121,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -173,9 +146,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -211,85 +188,75 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { - - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -304,10 +271,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -357,13 +324,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -377,9 +344,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/backend/plonk/plonk_test.go b/backend/plonk/plonk_test.go index 5d1fa1207e..55d1fb104f 100644 --- a/backend/plonk/plonk_test.go +++ b/backend/plonk/plonk_test.go @@ -290,7 +290,7 @@ func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circu expectedY.Exp(expectedY, exp, curve.ScalarField()) good.Y = expectedY - srs, srsLagrange, err := unsafekzg.NewSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs, unsafekzg.WithFSCache()) if err != nil { panic(err) } diff --git a/backend/plonkfri/bls12-377/prove.go b/backend/plonkfri/bls12-377/prove.go index 4e18f1a2b2..82144641b2 100644 --- a/backend/plonkfri/bls12-377/prove.go +++ b/backend/plonkfri/bls12-377/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bls12-381/prove.go b/backend/plonkfri/bls12-381/prove.go index a168b3a21c..0a9288bea4 100644 --- a/backend/plonkfri/bls12-381/prove.go +++ b/backend/plonkfri/bls12-381/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bls24-315/prove.go b/backend/plonkfri/bls24-315/prove.go index c1ef342215..5f58af944e 100644 --- a/backend/plonkfri/bls24-315/prove.go +++ b/backend/plonkfri/bls24-315/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bls24-317/prove.go b/backend/plonkfri/bls24-317/prove.go index 5cecfe60ba..5f761a382e 100644 --- a/backend/plonkfri/bls24-317/prove.go +++ b/backend/plonkfri/bls24-317/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bn254/prove.go b/backend/plonkfri/bn254/prove.go index cc3a9e6288..ee38864f28 100644 --- a/backend/plonkfri/bn254/prove.go +++ b/backend/plonkfri/bn254/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bw6-633/prove.go b/backend/plonkfri/bw6-633/prove.go index 9fb899b67e..b0c2a76d7e 100644 --- a/backend/plonkfri/bw6-633/prove.go +++ b/backend/plonkfri/bw6-633/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/backend/plonkfri/bw6-761/prove.go b/backend/plonkfri/bw6-761/prove.go index 38d74d0cc8..3ca4fd5413 100644 --- a/backend/plonkfri/bw6-761/prove.go +++ b/backend/plonkfri/bw6-761/prove.go @@ -428,9 +428,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF) diff --git a/go.mod b/go.mod index 8e32ac0d5c..ff34417136 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 github.com/consensys/compress v0.1.0 - github.com/consensys/gnark-crypto v0.12.2-0.20231221131605-1db1afbeb890 + github.com/consensys/gnark-crypto v0.12.2-0.20231221171913-5d5eded6bb15 github.com/fxamacker/cbor/v2 v2.5.0 github.com/google/go-cmp v0.5.9 github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b diff --git a/go.sum b/go.sum index 1807f4f810..138bb12567 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/Yj github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/compress v0.1.0 h1:fczDaganmx2198GudPo4+5VX3eBvKy/bEJfmNotbr70= github.com/consensys/compress v0.1.0/go.mod h1:Ne8+cGKjqgjF1dlHapZx38pHzWpaBYhsKxQa+JPl0zM= -github.com/consensys/gnark-crypto v0.12.2-0.20231221131605-1db1afbeb890 h1:cBAPKf1lkrGAfUDo1iqwPM1v40KXuhSlmKfqHEW2HXY= -github.com/consensys/gnark-crypto v0.12.2-0.20231221131605-1db1afbeb890/go.mod h1:YRXoiKN6EOw8ivAgGtd934yTJlGOhJ1uQDBMMG6HZVE= +github.com/consensys/gnark-crypto v0.12.2-0.20231221171913-5d5eded6bb15 h1:mcxhrDtXKIepsKXofxSuXRst+41yzAcoNWKIotsjMTQ= +github.com/consensys/gnark-crypto v0.12.2-0.20231221171913-5d5eded6bb15/go.mod h1:wKqwsieaKPThcFkHe0d0zMsbHEUWFmZcG7KBCse210o= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl index 839a512d06..348c567418 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/mpcsetup/lagrange.go.tmpl @@ -20,7 +20,8 @@ func lagrangeCoeffsG1(powers []curve.G1Affine, size int) []curve.G1Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG1(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG1(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int @@ -42,7 +43,8 @@ func lagrangeCoeffsG2(powers []curve.G2Affine, size int) []curve.G2Affine { numCPU := uint64(runtime.NumCPU()) maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(numCPU)) - difFFTG2(coeffs, domain.TwiddlesInv, 0, maxSplits, nil) + twiddlesInv, _ := domain.TwiddlesInv() + difFFTG2(coeffs, twiddlesInv, 0, maxSplits, nil) bitReverse(coeffs) var invBigint big.Int diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index 9d4104ed31..59de0bd7e0 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -1,10 +1,7 @@ import ( {{ template "import_curve" . }} - {{ template "import_fr" . }} {{ template "import_kzg" . }} - "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" "io" - "errors" ) // WriteRawTo writes binary encoding of Proof to w without point compression @@ -96,19 +93,7 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e return } - // fft domains - n2, err := pk.Domain[0].WriteTo(w) - if err != nil { - return - } - n += n2 - - n2, err = pk.Domain[1].WriteTo(w) - if err != nil { - return - } - n += n2 - + var n2 int64 // KZG key if withCompression { n2, err = pk.Kzg.WriteTo(w) @@ -129,35 +114,8 @@ func (pk *ProvingKey) writeTo(w io.Writer, withCompression bool) (n int64, err e } n += n2 - // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) - if len(pk.trace.S) != (3 * int(pk.Domain[0].Cardinality)) { - return n, errors.New("invalid permutation size, expected 3*domain cardinality") - } - - enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) - // so we explicitly transmit []fr.Element - toEncode := []interface{}{ - pk.trace.Ql.Coefficients(), - pk.trace.Qr.Coefficients(), - pk.trace.Qm.Coefficients(), - pk.trace.Qo.Coefficients(), - pk.trace.Qk.Coefficients(), - coefficients(pk.trace.Qcp), - pk.trace.S1.Coefficients(), - pk.trace.S2.Coefficients(), - pk.trace.S3.Coefficients(), - pk.trace.S, - } - for _, v := range toEncode { - if err := enc.Encode(v); err != nil { - return n + enc.BytesWritten(), err - } - } - - return n + enc.BytesWritten(), nil + return n, nil } // ReadFrom reads from binary representation in r into ProvingKey @@ -177,18 +135,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err return n, err } - n2, err, chDomain0 := pk.Domain[0].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - - n2, err, chDomain1 := pk.Domain[1].AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - + var n2 int64 if withSubgroupChecks { n2, err = pk.Kzg.ReadFrom(r) } else { @@ -204,99 +151,7 @@ func (pk *ProvingKey) readFrom(r io.Reader, withSubgroupChecks bool) (int64, err n2, err = pk.KzgLagrange.UnsafeReadFrom(r) } n += n2 - if err != nil { - return n, err - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - - dec := curve.NewDecoder(r) - - var ql, qr, qm, qo, qk, s1, s2, s3 []fr.Element - var qcp [][]fr.Element - - // TODO @gbotrel: this is a bit ugly, we should probably refactor this. - // The order of the variables is important, as it matches the order in which they are - // encoded in the WriteTo(...) method. - - // Note: instead of calling dec.Decode(...) for each of the above variables, - // we call AsyncReadFrom when possible which allows to consume bytes from the reader - // and perform the decoding in parallel - - type v struct { - data *fr.Vector - chErr chan error - } - - vectors := make([]v, 8) - vectors[0] = v{data: (*fr.Vector)(&ql)} - vectors[1] = v{data: (*fr.Vector)(&qr)} - vectors[2] = v{data: (*fr.Vector)(&qm)} - vectors[3] = v{data: (*fr.Vector)(&qo)} - vectors[4] = v{data: (*fr.Vector)(&qk)} - vectors[5] = v{data: (*fr.Vector)(&s1)} - vectors[6] = v{data: (*fr.Vector)(&s2)} - vectors[7] = v{data: (*fr.Vector)(&s3)} - - // read ql, qr, qm, qo, qk - for i := 0; i < 5; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read qcp - if err := dec.Decode(&qcp); err != nil { - return n + dec.BytesRead(), err - } - - // read lqk, s1, s2, s3 - for i := 5; i < 8; i++ { - n2, err, ch := vectors[i].data.AsyncReadFrom(r) - n += n2 - if err != nil { - return n, err - } - vectors[i].chErr = ch - } - - // read pk.Trace.S - if err := dec.Decode(&pk.trace.S); err != nil { - return n + dec.BytesRead(), err - } - - // wait for all AsyncReadFrom(...) to complete - for i := range vectors { - if err := <-vectors[i].chErr; err != nil { - return n, err - } - } - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, len(qcp)) - for i := range qcp { - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp[i], canReg) - } - - // wait for FFT to be precomputed - <-chDomain0 - <-chDomain1 - - - return n + dec.BytesRead(), nil - + return n, err } // WriteTo writes binary encoding of VerifyingKey to w diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index 9ff314e59e..6f4a971095 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -115,9 +115,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts // solve constraints g.Go(instance.solveConstraints) - // compute numerator data - g.Go(instance.initComputeNumerator) - // complete qk g.Go(instance.completeQk) @@ -188,9 +185,6 @@ type instance struct { // challenges gamma, beta, alpha, zeta fr.Element - // compute numerator data - cres, twiddles0, cosetTableRev, twiddlesRev []fr.Element - // channel to wait for the steps chLRO, chQk, @@ -201,8 +195,11 @@ type instance struct { chZOpening, chLinearizedPolynomial, chFoldedH, - chNumeratorInit, chGammaBeta chan struct{} + + domain0, domain1 *fft.Domain + + trace *Trace } func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWitness witness.Witness, opts *backend.ProverConfig) (*instance, error) { @@ -230,44 +227,31 @@ func newInstance(ctx context.Context, spr *cs.SparseR1CS, pk *ProvingKey, fullWi chLinearizedPolynomial: make(chan struct{}, 1), chFoldedH: make(chan struct{}, 1), chRestoreLRO: make(chan struct{}, 1), - chNumeratorInit: make(chan struct{}, 1), } s.initBSB22Commitments() s.setupGKRHints() s.x = make([]*iop.Polynomial, id_Qci+2*len(s.commitmentInfo)) - return &s, nil -} + // init fft domains + nbConstraints := spr.GetNbConstraints() + sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints + s.domain0 = fft.NewDomain(sizeSystem) -func (s *instance) initComputeNumerator() error { - n := s.pk.Domain[0].Cardinality - s.cres = make([]fr.Element, s.pk.Domain[1].Cardinality) - s.twiddles0 = make([]fr.Element, n) - if n == 1 { - // edge case - s.twiddles0[0].SetOne() + // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, + // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases + // except when n<6. + if sizeSystem < 6 { + s.domain1 = fft.NewDomain(8 * sizeSystem, fft.WithoutPrecompute()) } else { - copy(s.twiddles0, s.pk.Domain[0].Twiddles[0]) - for i := len(s.pk.Domain[0].Twiddles[0]); i < len(s.twiddles0); i++ { - s.twiddles0[i].Mul(&s.twiddles0[i-1], &s.twiddles0[1]) - } + s.domain1 = fft.NewDomain(4 * sizeSystem, fft.WithoutPrecompute()) } - - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] - - s.cosetTableRev = make([]fr.Element, len(cosetTable)) - copy(s.cosetTableRev, cosetTable) - fft.BitReverse(s.cosetTableRev) + // TODO @gbotrel domain1 is used for only 1 FFT --> precomputing the twiddles + // and storing them in memory is costly given its size. --> do a FFT on the fly - s.twiddlesRev = make([]fr.Element, len(twiddles)) - copy(s.twiddlesRev, twiddles) - fft.BitReverse(s.twiddlesRev) + // build trace + s.trace = NewTrace(spr, s.domain0) - close(s.chNumeratorInit) - - return nil + return &s, nil } func (s *instance) initBlindingPolynomials() error { @@ -299,7 +283,7 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { res := &s.commitmentVal[commDepth] commitmentInfo := s.spr.CommitmentInfo.(constraint.PlonkCommitments)[commDepth] - committedValues := make([]fr.Element, s.pk.Domain[0].Cardinality) + committedValues := make([]fr.Element, s.domain0.Cardinality) offset := s.spr.GetNbPublicVariables() for i := range ins { committedValues[offset+commitmentInfo.Committed[i]].SetBigInt(ins[i]) @@ -314,7 +298,6 @@ func (s *instance) bsb22Hint(_ *big.Int, ins, outs []*big.Int) error { if s.proof.Bsb22Commitments[commDepth], err = kzg.Commit(s.cCommitments[commDepth].Coefficients(), s.pk.KzgLagrange); err != nil { return err } - s.cCommitments[commDepth].ToCanonical(&s.pk.Domain[0]).ToRegular() s.htfFunc.Write(s.proof.Bsb22Commitments[commDepth].Marshal()) hashBts := s.htfFunc.Sum(nil) @@ -373,7 +356,7 @@ func (s *instance) solveConstraints() error { } func (s *instance) completeQk() error { - qk := s.pk.trace.Qk.Clone().ToLagrange(&s.pk.Domain[0]).ToRegular() + qk := s.trace.Qk.Clone() qkCoeffs := qk.Coefficients() wWitness, ok := s.fullWitness.Vector().(fr.Vector) @@ -472,7 +455,7 @@ func (s *instance) commitToPolyAndBlinding(p, b *iop.Polynomial) (commit curve.G commit, err = kzg.Commit(p.Coefficients(), s.pk.KzgLagrange) // we add in the blinding contribution - n := int(s.pk.Domain[0].Cardinality) + n := int(s.domain0.Cardinality) cb := commitBlindingFactor(n, b, s.pk.Kzg) commit.Add(&commit, &cb) @@ -496,20 +479,19 @@ func (s *instance) deriveZeta() (err error) { // evaluateConstraints computes H func (s *instance) evaluateConstraints() (err error) { - // clone polys from the proving key. - s.x[id_Ql] = s.pk.trace.Ql.Clone() - s.x[id_Qr] = s.pk.trace.Qr.Clone() - s.x[id_Qm] = s.pk.trace.Qm.Clone() - s.x[id_Qo] = s.pk.trace.Qo.Clone() - s.x[id_S1] = s.pk.trace.S1.Clone() - s.x[id_S2] = s.pk.trace.S2.Clone() - s.x[id_S3] = s.pk.trace.S3.Clone() + s.x[id_Ql] = s.trace.Ql + s.x[id_Qr] = s.trace.Qr + s.x[id_Qm] = s.trace.Qm + s.x[id_Qo] = s.trace.Qo + s.x[id_S1] = s.trace.S1 + s.x[id_S2] = s.trace.S2 + s.x[id_S3] = s.trace.S3 for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i] = s.pk.trace.Qcp[i].Clone() + s.x[id_Qci+2*i] = s.trace.Qcp[i] } - n := s.pk.Domain[0].Cardinality + n := s.domain0.Cardinality lone := make([]fr.Element, n) lone[0].SetOne() @@ -521,7 +503,7 @@ func (s *instance) evaluateConstraints() (err error) { } for i := 0; i < len(s.commitmentInfo); i++ { - s.x[id_Qci+2*i+1] = s.cCommitments[i].Clone() + s.x[id_Qci+2*i+1] = s.cCommitments[i] } // wait for Z to be committed or context done @@ -549,7 +531,7 @@ func (s *instance) evaluateConstraints() (err error) { return err } - s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{&s.pk.Domain[0], &s.pk.Domain[1]}) + s.h, err = divideByXMinusOne(numerator, [2]*fft.Domain{s.domain0, s.domain1}) if err != nil { return err } @@ -591,11 +573,11 @@ func (s *instance) buildRatioCopyConstraint() (err error) { s.x[id_R], s.x[id_O], }, - s.pk.trace.S, + s.trace.S, s.beta, s.gamma, iop.Form{Basis: iop.Lagrange, Layout: iop.Regular}, - &s.pk.Domain[0], + s.domain0, ) if err != nil { return err @@ -630,17 +612,17 @@ func (s *instance) openZ() (err error) { } func (s *instance) h1() []fr.Element { - h1 := s.h.Coefficients()[:s.pk.Domain[0].Cardinality+2] + h1 := s.h.Coefficients()[:s.domain0.Cardinality+2] return h1 } func (s *instance) h2() []fr.Element { - h2 := s.h.Coefficients()[s.pk.Domain[0].Cardinality+2 : 2*(s.pk.Domain[0].Cardinality+2)] + h2 := s.h.Coefficients()[s.domain0.Cardinality+2 : 2*(s.domain0.Cardinality+2)] return h2 } func (s *instance) h3() []fr.Element { - h3 := s.h.Coefficients()[2*(s.pk.Domain[0].Cardinality+2) : 3*(s.pk.Domain[0].Cardinality+2)] + h3 := s.h.Coefficients()[2*(s.domain0.Cardinality+2) : 3*(s.domain0.Cardinality+2)] return h3 } @@ -653,7 +635,7 @@ func (s *instance) foldH() error { case <-s.chH: } var n big.Int - n.SetUint64(s.pk.Domain[0].Cardinality + 2) + n.SetUint64(s.domain0.Cardinality + 2) var zetaPowerNplusTwo fr.Element zetaPowerNplusTwo.Exp(s.zeta, &n) @@ -669,7 +651,7 @@ func (s *instance) foldH() error { h2 := s.h2() s.foldedH = s.h3() - for i := 0; i < int(s.pk.Domain[0].Cardinality)+2; i++ { + for i := 0; i < int(s.domain0.Cardinality)+2; i++ { s.foldedH[i]. Mul(&s.foldedH[i], &zetaPowerNplusTwo). Add(&s.foldedH[i], &h2[i]). @@ -698,7 +680,7 @@ func (s *instance) computeLinearizedPolynomial() error { for i := 0; i < len(s.commitmentInfo); i++ { go func(i int) { - qcpzeta[i] = s.pk.trace.Qcp[i].Evaluate(s.zeta) + qcpzeta[i] = s.trace.Qcp[i].Evaluate(s.zeta) wg.Done() }(i) } @@ -728,7 +710,7 @@ func (s *instance) computeLinearizedPolynomial() error { wg.Wait() - s.linearizedPolynomial = computeLinearizedPolynomial( + s.linearizedPolynomial = s.innerComputeLinearizedPoly( blzeta, brzeta, bozeta, @@ -753,9 +735,6 @@ func (s *instance) computeLinearizedPolynomial() error { } func (s *instance) batchOpening() error { - polysQcp := coefficients(s.pk.trace.Qcp) - polysToOpen := make([][]fr.Element, 7+len(polysQcp)) - copy(polysToOpen[7:], polysQcp) // wait for LRO to be committed (or ctx.Done()) select { @@ -778,13 +757,17 @@ func (s *instance) batchOpening() error { case <-s.chLinearizedPolynomial: } + polysQcp := coefficients(s.trace.Qcp) + polysToOpen := make([][]fr.Element, 7+len(polysQcp)) + copy(polysToOpen[7:], polysQcp) + polysToOpen[0] = s.foldedH polysToOpen[1] = s.linearizedPolynomial polysToOpen[2] = getBlindedCoefficients(s.x[id_L], s.bp[id_Bl]) polysToOpen[3] = getBlindedCoefficients(s.x[id_R], s.bp[id_Br]) polysToOpen[4] = getBlindedCoefficients(s.x[id_O], s.bp[id_Bo]) - polysToOpen[5] = s.pk.trace.S1.Coefficients() - polysToOpen[6] = s.pk.trace.S2.Coefficients() + polysToOpen[5] = s.trace.S1.Coefficients() + polysToOpen[6] = s.trace.S2.Coefficients() digestsToOpen := make([]curve.G1Affine, len(s.pk.Vk.Qcp)+7) copy(digestsToOpen[7:], s.pk.Vk.Qcp) @@ -813,6 +796,23 @@ func (s *instance) batchOpening() error { // evaluate the full set of constraints, all polynomials in x are back in // canonical regular form at the end func (s *instance) computeNumerator() (*iop.Polynomial, error) { + // init vectors that are used multiple times throughout the computation + n := s.domain0.Cardinality + twiddles0 := make([]fr.Element, n) + if n == 1 { + // edge case + twiddles0[0].SetOne() + } else { + twiddles, err := s.domain0.Twiddles() + if err != nil { + return nil, err + } + copy(twiddles0, twiddles[0]) + w := twiddles0[1] + for i := len(twiddles[0]); i < len(twiddles0); i++ { + twiddles0[i].Mul(&twiddles0[i-1], &w) + } + } // wait for chQk to be closed (or ctx.Done()) select { @@ -821,7 +821,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { case <-s.chQk: } - n := s.pk.Domain[0].Cardinality + nbBsbGates := (len(s.x) - id_Qci + 1) >> 1 @@ -846,7 +846,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { var cs, css fr.Element - cs.Set(&s.pk.Domain[1].FrMultiplicativeGen) + cs.Set(&s.domain1.FrMultiplicativeGen) css.Square(&cs) orderingConstraint := func(u ...fr.Element) fr.Element { @@ -878,11 +878,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return res } - rho := int(s.pk.Domain[1].Cardinality / n) + rho := int(s.domain1.Cardinality / n) shifters := make([]fr.Element, rho) - shifters[0].Set(&s.pk.Domain[1].FrMultiplicativeGen) + shifters[0].Set(&s.domain1.FrMultiplicativeGen) for i := 1; i < rho; i++ { - shifters[i].Set(&s.pk.Domain[1].Generator) + shifters[i].Set(&s.domain1.Generator) } // stores the current coset shifter @@ -893,14 +893,13 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { one.SetOne() bn := big.NewInt(int64(n)) - // wait for init go routine - <-s.chNumeratorInit - - cosetTable := s.pk.Domain[0].CosetTable - twiddles := s.pk.Domain[1].Twiddles[0][:n] + cosetTable, err := s.domain0.CosetTable() + if err != nil { + return nil, err + } // init the result polynomial & buffer - cres := s.cres + cres := make([]fr.Element, s.domain1.Cardinality) buf := make([]fr.Element, n) var wgBuf sync.WaitGroup @@ -912,17 +911,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // blind L, R, O, Z, ZS var y fr.Element - y = s.bp[id_Bl].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bl].Evaluate(twiddles0[i]) u[id_L].Add(&u[id_L], &y) - y = s.bp[id_Br].Evaluate(s.twiddles0[i]) + y = s.bp[id_Br].Evaluate(twiddles0[i]) u[id_R].Add(&u[id_R], &y) - y = s.bp[id_Bo].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bo].Evaluate(twiddles0[i]) u[id_O].Add(&u[id_O], &y) - y = s.bp[id_Bz].Evaluate(s.twiddles0[i]) + y = s.bp[id_Bz].Evaluate(twiddles0[i]) u[id_Z].Add(&u[id_Z], &y) // ZS is shifted by 1; need to get correct twiddle - y = s.bp[id_Bz].Evaluate(s.twiddles0[(i+1)%int(n)]) + y = s.bp[id_Bz].Evaluate(twiddles0[(i+1)%int(n)]) u[id_ZS].Add(&u[id_ZS], &y) a := gateConstraint(u...) @@ -932,28 +931,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { return c } - // select the correct scaling vector to scale by shifter[i] - selectScalingVector := func(i int, l iop.Layout) []fr.Element { - var w []fr.Element - if i == 0 { - if l == iop.Regular { - w = cosetTable - } else { - w = s.cosetTableRev - } - } else { - if l == iop.Regular { - w = twiddles - } else { - w = s.twiddlesRev - } - } - return w - } + // for the first iteration, the scalingVector is the coset table + scalingVector := cosetTable + scalingVectorRev := make([]fr.Element, len(cosetTable)) + copy(scalingVectorRev, cosetTable) + fft.BitReverse(scalingVectorRev) // pre-computed to compute the bit reverse index // of the result polynomial - m := uint64(s.pk.Domain[1].Cardinality) + m := uint64(s.domain1.Cardinality) mm := uint64(64 - bits.TrailingZeros64(m)) for i := 0; i < rho; i++ { @@ -970,6 +956,17 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { acc.Mul(&acc, &shifters[i]) } } + if i == 1 { + // we have to update the scalingVector; instead of scaling by + // cosets we scale by the twiddles of the large domain. + w := s.domain1.Generator + scalingVector = make([]fr.Element, n) + fft.BuildExpTable(w, scalingVector) + + // reuse memory + copy(scalingVectorRev, scalingVector) + fft.BitReverse(scalingVectorRev) + } // we do **a lot** of FFT here, but on the small domain. // note that for all the polynomials in the proving key @@ -979,10 +976,15 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { batchApply(s.x, func(p *iop.Polynomial) { nbTasks := calculateNbTasks(len(s.x)-1) * 2 // shift polynomials to be in the correct coset - p.ToCanonical(&s.pk.Domain[0], nbTasks) + p.ToCanonical(s.domain0, nbTasks) // scale by shifter[i] - w := selectScalingVector(i, p.Layout) + var w []fr.Element + if p.Layout == iop.Regular { + w = scalingVector + } else { + w = scalingVectorRev + } cp := p.Coefficients() utils.Parallelize(len(cp), func(start, end int) { @@ -992,7 +994,7 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { }, nbTasks) // fft in the correct coset - p.ToLagrange(&s.pk.Domain[0], nbTasks).ToRegular() + p.ToLagrange(s.domain0, nbTasks).ToRegular() }) wgBuf.Wait() @@ -1025,9 +1027,10 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { // scale everything back go func() { - for i := id_ZS; i < len(s.x); i++ { - s.x[i] = nil - } + s.x[id_ID] = nil + s.x[id_LOne] = nil + s.x[id_ZS] = nil + s.x[id_Qk] = nil var cs fr.Element cs.Set(&shifters[0]) @@ -1036,8 +1039,11 @@ func (s *instance) computeNumerator() (*iop.Polynomial, error) { } cs.Inverse(&cs) - batchApply(s.x[:id_ZS], func(p *iop.Polynomial) { - p.ToCanonical(&s.pk.Domain[0], 8).ToRegular() + batchApply(s.x, func(p *iop.Polynomial) { + if p == nil { + return + } + p.ToCanonical(s.domain0, 8).ToRegular() scalePowers(p, cs) }) @@ -1254,7 +1260,7 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { return res } -// computeLinearizedPolynomial computes the linearized polynomial in canonical basis. +// innerComputeLinearizedPoly computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. // * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(μX), the shifted version of Z @@ -1265,8 +1271,8 @@ func evaluateXnMinusOneDomainBigCoset(domains [2]*fft.Domain) []fr.Element { // α²*L₁(ζ)*Z(X) // + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) -func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { - +func (s *instance) innerComputeLinearizedPoly(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, qcpZeta, blindedZCanonical []fr.Element, pi2Canonical [][]fr.Element, pk *ProvingKey) []fr.Element { + // TODO @gbotrel rename // first part: individual constraints var rl fr.Element rl.Mul(&rZeta, &lZeta) @@ -1276,12 +1282,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.trace.S1.Evaluate(zeta) // s1(ζ) + s1 = s.trace.S1.Evaluate(zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() // ps2 := iop.NewPolynomial(&pk.S2Canonical, iop.Form{Basis: iop.Canonical, Layout: iop.Regular}) - tmp := pk.trace.S2.Evaluate(zeta) // s2(ζ) + tmp := s.trace.S2.Evaluate(zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) @@ -1300,7 +1306,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.Domain[0].Cardinality) + nbElmt := int64(s.domain0.Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -1310,17 +1316,19 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &s.domain0.CardinalityInv) // (1/n)*α²*L₁(ζ) + + s3canonical := s.trace.S3.Coefficients() - s3canonical := pk.trace.S3.Coefficients() + s.trace.Qk.ToCanonical(s.domain0).ToRegular() utils.Parallelize(len(blindedZCanonical), func(start, end int) { - cql := pk.trace.Ql.Coefficients() - cqr := pk.trace.Qr.Coefficients() - cqm := pk.trace.Qm.Coefficients() - cqo := pk.trace.Qo.Coefficients() - cqk := pk.trace.Qk.Coefficients() + cql := s.trace.Ql.Coefficients() + cqr := s.trace.Qr.Coefficients() + cqm := s.trace.Qm.Coefficients() + cqo := s.trace.Qo.Coefficients() + cqk := s.trace.Qk.Coefficients() var t, t0, t1 fr.Element diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index bba64e9a40..8401930fc9 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -60,34 +60,14 @@ type Trace struct { S []int64 } -// ProvingKey stores the data needed to generate a proof: -// * the commitment scheme -// * ql, prepended with as many ones as they are public inputs -// * qr, qm, qo prepended with as many zeroes as there are public inputs. -// * qk, prepended with as many zeroes as public inputs, to be completed by the prover -// with the list of public inputs. -// * sigma_1, sigma_2, sigma_3 in both basis -// * the copy constraint permutation +// ProvingKey stores the data needed to generate a proof type ProvingKey struct { - // stores ql, qr, qm, qo, qk (-> to be completed by the prover) - // and s1, s2, s3. They are set in canonical basis before generating the proof, they will be used - // for computing the opening proofs (hence the canonical form). The canonical version - // of qk incomplete is used in the linearisation polynomial. - // The polynomials in trace are in canonical basis. - trace Trace - Kzg, KzgLagrange kzg.ProvingKey // Verifying Key is embedded into the proving key (needed by Prove) Vk *VerifyingKey - - // Domains used for the FFTs. - // Domain[0] = small Domain - // Domain[1] = big Domain - Domain [2]fft.Domain } -// TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey @@ -96,26 +76,26 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.CommitmentConstraintIndexes = internal.IntSliceToUint64Slice(spr.CommitmentInfo.CommitmentIndexes()) // step 0: set the fft domains - pk.initDomains(spr) - if pk.Domain[0].Cardinality < 2 { + domain := initFFTDomain(spr) + if domain.Cardinality < 2 { return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } // check the size of the kzg srs. - if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly - return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + if len(srs.Pk.G1) < (int(domain.Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), domain.Cardinality+3) } // same for the lagrange form - if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { - return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + if len(srsLagrange.Pk.G1) != int(domain.Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), domain.Cardinality) } // step 1: set the verifying key - pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) - vk.Size = pk.Domain[0].Cardinality + vk.CosetShift.Set(&domain.FrMultiplicativeGen) + vk.Size = domain.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.Domain[0].Generator) + vk.Generator.Set(&domain.Generator) vk.NbPublicVariables = uint64(len(spr.Public)) pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] @@ -123,22 +103,15 @@ func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *Verifyin vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis - BuildTrace(spr, &pk.trace) - // step 3: build the permutation and build the polynomials S1, S2, S3 to encode the permutation. // Note: at this stage, the permutation takes in account the placeholders - nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - buildPermutation(spr, &pk.trace, nbVariables) - s := computePermutationPolynomials(&pk.trace, &pk.Domain[0]) - pk.trace.S1 = s[0] - pk.trace.S2 = s[1] - pk.trace.S3 = s[2] + trace := NewTrace(spr, domain) // step 4: commit to s1, s2, s3, ql, qr, qm, qo, and (the incomplete version of) qk. // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err := commitTrace(&pk.trace, &pk); err != nil { + if err := vk.commitTrace(trace, domain, pk.KzgLagrange); err != nil { return nil, nil, err } @@ -155,9 +128,13 @@ func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk } -// BuildTrace fills the constant columns ql, qr, qm, qo, qk from the sparser1cs. -// Size is the size of the system that is nb_constraints+nb_public_variables -func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { +// NewTrace returns a new Trace object from the constraint system. +// It fills the constant columns ql, qr, qm, qo, qk, and qcp with the +// coefficients of the constraints. +// Size is the size of the system that is next power of 2 (nb_constraints+nb_public_variables) +// The permutation is also computed and stored in the Trace. +func NewTrace(spr *cs.SparseR1CS, domain *fft.Domain) *Trace { + var trace Trace nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) @@ -193,85 +170,76 @@ func BuildTrace(spr *cs.SparseR1CS, pt *Trace) { lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} - pt.Ql = iop.NewPolynomial(&ql, lagReg) - pt.Qr = iop.NewPolynomial(&qr, lagReg) - pt.Qm = iop.NewPolynomial(&qm, lagReg) - pt.Qo = iop.NewPolynomial(&qo, lagReg) - pt.Qk = iop.NewPolynomial(&qk, lagReg) - pt.Qcp = make([]*iop.Polynomial, len(qcp)) + trace.Ql = iop.NewPolynomial(&ql, lagReg) + trace.Qr = iop.NewPolynomial(&qr, lagReg) + trace.Qm = iop.NewPolynomial(&qm, lagReg) + trace.Qo = iop.NewPolynomial(&qo, lagReg) + trace.Qk = iop.NewPolynomial(&qk, lagReg) + trace.Qcp = make([]*iop.Polynomial, len(qcp)) for i := range commitmentInfo { qcp[i] = make([]fr.Element, size) for _, committed := range commitmentInfo[i].Committed { qcp[i][offset+committed].SetOne() } - pt.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) + trace.Qcp[i] = iop.NewPolynomial(&qcp[i], lagReg) } + + // build the permutation and build the polynomials S1, S2, S3 to encode the permutation. + // Note: at this stage, the permutation takes in account the placeholders + nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) + buildPermutation(spr, &trace, nbVariables) + s := computePermutationPolynomials(&trace, domain) + trace.S1 = s[0] + trace.S2 = s[1] + trace.S3 = s[2] + + return &trace } // commitTrace commits to every polynomial in the trace, and put // the commitments int the verifying key. -func commitTrace(trace *Trace, pk *ProvingKey) error { +func (vk *VerifyingKey) commitTrace(trace *Trace, domain *fft.Domain, srsPk kzg.ProvingKey) error { - trace.Ql.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qr.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qm.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qo.ToCanonical(&pk.Domain[0]).ToRegular() - trace.Qk.ToCanonical(&pk.Domain[0]).ToRegular() // -> qk is not complete - trace.S1.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S2.ToCanonical(&pk.Domain[0]).ToRegular() - trace.S3.ToCanonical(&pk.Domain[0]).ToRegular() var err error - pk.Vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) + vk.Qcp = make([]kzg.Digest, len(trace.Qcp)) for i := range trace.Qcp { - trace.Qcp[i].ToCanonical(&pk.Domain[0]).ToRegular() - if pk.Vk.Qcp[i], err = kzg.Commit(pk.trace.Qcp[i].Coefficients(), pk.Kzg); err != nil { + if vk.Qcp[i], err = kzg.Commit(trace.Qcp[i].Coefficients(), srsPk); err != nil { return err } } - if pk.Vk.Ql, err = kzg.Commit(pk.trace.Ql.Coefficients(), pk.Kzg); err != nil { + if vk.Ql, err = kzg.Commit(trace.Ql.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qr, err = kzg.Commit(pk.trace.Qr.Coefficients(), pk.Kzg); err != nil { + if vk.Qr, err = kzg.Commit(trace.Qr.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qm, err = kzg.Commit(pk.trace.Qm.Coefficients(), pk.Kzg); err != nil { + if vk.Qm, err = kzg.Commit(trace.Qm.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qo, err = kzg.Commit(pk.trace.Qo.Coefficients(), pk.Kzg); err != nil { + if vk.Qo, err = kzg.Commit(trace.Qo.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.Qk, err = kzg.Commit(pk.trace.Qk.Coefficients(), pk.Kzg); err != nil { + if vk.Qk, err = kzg.Commit(trace.Qk.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[0], err = kzg.Commit(pk.trace.S1.Coefficients(), pk.Kzg); err != nil { + if vk.S[0], err = kzg.Commit(trace.S1.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[1], err = kzg.Commit(pk.trace.S2.Coefficients(), pk.Kzg); err != nil { + if vk.S[1], err = kzg.Commit(trace.S2.Coefficients(), srsPk); err != nil { return err } - if pk.Vk.S[2], err = kzg.Commit(pk.trace.S3.Coefficients(), pk.Kzg); err != nil { + if vk.S[2], err = kzg.Commit(trace.S3.Coefficients(), srsPk); err != nil { return err } return nil } -func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { - +func initFFTDomain(spr *cs.SparseR1CS) *fft.Domain { nbConstraints := spr.GetNbConstraints() sizeSystem := uint64(nbConstraints + len(spr.Public)) // len(spr.Public) is for the placeholder constraints - pk.Domain[0] = *fft.NewDomain(sizeSystem) - - // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, - // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases - // except when n<6. - if sizeSystem < 6 { - pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) - } else { - pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) - } - + return fft.NewDomain(sizeSystem, fft.WithoutPrecompute()) } // buildPermutation builds the Permutation associated with a circuit. @@ -286,10 +254,10 @@ func (pk *ProvingKey) initDomains(spr *cs.SparseR1CS) { // The permutation is encoded as a slice s of size 3*size(l), where the // i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] -func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { +func buildPermutation(spr *cs.SparseR1CS, trace *Trace, nbVariables int) { // nbVariables := spr.NbInternalVariables + len(spr.Public) + len(spr.Secret) - sizeSolution := len(pt.Ql.Coefficients()) + sizeSolution := len(trace.Ql.Coefficients()) sizePermutation := 3 * sizeSolution // init permutation @@ -339,13 +307,13 @@ func buildPermutation(spr *cs.SparseR1CS, pt *Trace, nbVariables int) { } } - pt.S = permutation + trace.S = permutation } // computePermutationPolynomials computes the LDE (Lagrange basis) of the permutation. // We let the permutation act on || u || u^{2}, split the result in 3 parts, // and interpolate each of the 3 parts on . -func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polynomial { +func computePermutationPolynomials(trace *Trace, domain *fft.Domain) [3]*iop.Polynomial { nbElmts := int(domain.Cardinality) @@ -359,9 +327,9 @@ func computePermutationPolynomials(pt *Trace, domain *fft.Domain) [3]*iop.Polyno s2Canonical := make([]fr.Element, nbElmts) s3Canonical := make([]fr.Element, nbElmts) for i := 0; i < nbElmts; i++ { - s1Canonical[i].Set(&evaluationIDSmallDomain[pt.S[i]]) - s2Canonical[i].Set(&evaluationIDSmallDomain[pt.S[nbElmts+i]]) - s3Canonical[i].Set(&evaluationIDSmallDomain[pt.S[2*nbElmts+i]]) + s1Canonical[i].Set(&evaluationIDSmallDomain[trace.S[i]]) + s2Canonical[i].Set(&evaluationIDSmallDomain[trace.S[nbElmts+i]]) + s3Canonical[i].Set(&evaluationIDSmallDomain[trace.S[2*nbElmts+i]]) } lagReg := iop.Form{Basis: iop.Lagrange, Layout: iop.Regular} diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl index de4f14da75..375c4a9d59 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl @@ -2,12 +2,10 @@ import ( {{ template "import_curve" . }} {{ template "import_fr" . }} - {{ template "import_fft" . }} "testing" "math/big" "math/rand" "github.com/consensys/gnark/io" - "github.com/consensys/gnark-crypto/ecc/{{toLower .Curve}}/fr/iop" "github.com/stretchr/testify/assert" ) @@ -41,8 +39,6 @@ func (pk *ProvingKey) randomize() { var vk VerifyingKey vk.randomize() pk.Vk = &vk - pk.Domain[0] = *fft.NewDomain(32) - pk.Domain[1] = *fft.NewDomain(4 * 32) pk.Kzg.G1 = make([]curve.G1Affine, 32) pk.KzgLagrange.G1 = make([]curve.G1Affine, 32) @@ -51,36 +47,6 @@ func (pk *ProvingKey) randomize() { pk.KzgLagrange.G1[i] = randomG1Point() } - n := int(pk.Domain[0].Cardinality) - ql := randomScalars(n) - qr := randomScalars(n) - qm := randomScalars(n) - qo := randomScalars(n) - qk := randomScalars(n) - s1 := randomScalars(n) - s2 := randomScalars(n) - s3 := randomScalars(n) - - canReg := iop.Form{Basis: iop.Canonical, Layout: iop.Regular} - pk.trace.Ql = iop.NewPolynomial(&ql, canReg) - pk.trace.Qr = iop.NewPolynomial(&qr, canReg) - pk.trace.Qm = iop.NewPolynomial(&qm, canReg) - pk.trace.Qo = iop.NewPolynomial(&qo, canReg) - pk.trace.Qk = iop.NewPolynomial(&qk, canReg) - pk.trace.S1 = iop.NewPolynomial(&s1, canReg) - pk.trace.S2 = iop.NewPolynomial(&s2, canReg) - pk.trace.S3 = iop.NewPolynomial(&s3, canReg) - - pk.trace.Qcp = make([]*iop.Polynomial, rand.Intn(4)) //#nosec G404 weak rng is fine here - for i := range pk.trace.Qcp { - qcp := randomScalars(rand.Intn(n / 4)) //#nosec G404 weak rng is fine here - pk.trace.Qcp[i] = iop.NewPolynomial(&qcp, canReg) - } - - pk.trace.S = make([]int64, 3*pk.Domain[0].Cardinality) - pk.trace.S[0] = -12 - pk.trace.S[len(pk.trace.S)-1] = 8888 - } func (vk *VerifyingKey) randomize() { diff --git a/internal/generator/backend/template/zkpschemes/plonkfri/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonkfri/plonk.prove.go.tmpl index d0ec33da5b..fc6ce66f4f 100644 --- a/internal/generator/backend/template/zkpschemes/plonkfri/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonkfri/plonk.prove.go.tmpl @@ -407,9 +407,13 @@ func fftBigCosetWOBitReverse(poly []fr.Element, domainBig *fft.Domain) []fr.Elem // we copy poly in res and scale by coset here // to avoid FFT scaling on domainBig.Cardinality (res is very sparse) + cosetTable, err := domainBig.CosetTable() + if err != nil { + panic(err) + } utils.Parallelize(len(poly), func(start, end int) { for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainBig.CosetTable[i]) + res[i].Mul(&poly[i], &cosetTable[i]) } }, runtime.NumCPU()/2) domainBig.FFT(res, fft.DIF)