diff --git a/backend/backend.go b/backend/backend.go index 0873e69c24..a9f3cc15f6 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -15,6 +15,13 @@ // Package backend implements Zero Knowledge Proof systems: it consumes circuit compiled with gnark/frontend. package backend +import ( + "io" + "os" + + "github.com/consensys/gnark/backend/hint" +) + // ID represent a unique ID for a proving scheme type ID uint16 @@ -40,3 +47,48 @@ func (id ID) String() string { return "unknown" } } + +// NewProverOption returns a default ProverOption with given options applied +func NewProverOption(opts ...func(opt *ProverOption) error) (ProverOption, error) { + opt := ProverOption{LoggerOut: os.Stdout} + for _, option := range opts { + if err := option(&opt); err != nil { + return ProverOption{}, err + } + } + return opt, nil +} + +// ProverOption is shared accross backends to parametrize calls to xxx.Prove(...) +type ProverOption struct { + Force bool // default to false + HintFunctions []hint.Function // default to nil (use only solver std hints) + LoggerOut io.Writer // default to os.Stdout +} + +// IgnoreSolverError is a ProverOption that indicates that the Prove algorithm +// should complete, even if constraint system is not solved. +// In that case, Prove will output an invalid Proof, but will execute all algorithms +// which is useful for test and benchmarking purposes +func IgnoreSolverError(opt *ProverOption) error { + opt.Force = true + return nil +} + +// WithHints is a Prover option that specifies additional hint functions to be used +// by the constraint solver +func WithHints(hintFunctions ...hint.Function) func(opt *ProverOption) error { + return func(opt *ProverOption) error { + opt.HintFunctions = append(opt.HintFunctions, hintFunctions...) + return nil + } +} + +// WithOutput is a Prover option that specifies an io.Writer as destination for logs printed by +// cs.Println(). If set to nil, no logs are printed. +func WithOutput(w io.Writer) func(opt *ProverOption) error { + return func(opt *ProverOption) error { + opt.LoggerOut = w + return nil + } +} diff --git a/backend/groth16/assert.go b/backend/groth16/assert.go index 068f22311d..edda493ccf 100644 --- a/backend/groth16/assert.go +++ b/backend/groth16/assert.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness" @@ -47,11 +47,11 @@ func NewAssert(t *testing.T) *Assert { } // ProverFailed check that a witness does NOT solve a circuit -func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // setup pk, err := DummySetup(r1cs) assert.NoError(err) - _, err = Prove(r1cs, pk, witness, hintFunctions) + _, err = Prove(r1cs, pk, witness, opts...) assert.Error(err, "proving with bad witness should output an error") } @@ -68,7 +68,7 @@ func (assert *Assert) ProverFailed(r1cs frontend.CompiledConstraintSystem, witne // 5. Ensure deserialization(serialization) of generated objects is correct // // ensure result vectors a*b=c, and check other properties like random sampling -func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // setup pk, vk, err := Setup(r1cs) assert.NoError(err) @@ -84,17 +84,17 @@ func (assert *Assert) ProverSucceeded(r1cs frontend.CompiledConstraintSystem, wi } // ensure expected Values are computed correctly - assert.SolvingSucceeded(r1cs, witness, hintFunctions...) + assert.SolvingSucceeded(r1cs, witness, opts...) // extract full witness & public witness // prover - proof, err := Prove(r1cs, pk, witness, hintFunctions) + proof, err := Prove(r1cs, pk, witness, opts...) assert.NoError(err, "proving with good witness should not output an error") // ensure random sampling; calling prove twice with same witness should produce different proof { - proof2, err := Prove(r1cs, pk, witness, hintFunctions) + proof2, err := Prove(r1cs, pk, witness, opts...) assert.NoError(err, "proving with good witness should not output an error") assert.False(reflect.DeepEqual(proof, proof2), "calling prove twice with same input should produce different proof") } @@ -130,49 +130,56 @@ func (assert *Assert) SerializationRawSucceeded(from gnarkio.WriterRawTo, to io. } // SolvingSucceeded Verifies that the R1CS is solved with the given witness, without executing groth16 workflow -func (assert *Assert) SolvingSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { - assert.NoError(IsSolved(r1cs, witness, hintFunctions)) +func (assert *Assert) SolvingSucceeded(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { + assert.NoError(IsSolved(r1cs, witness, opts...)) } // SolvingFailed Verifies that the R1CS is not solved with the given witness, without executing groth16 workflow -func (assert *Assert) SolvingFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { - assert.Error(IsSolved(r1cs, witness, hintFunctions)) +func (assert *Assert) SolvingFailed(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { + assert.Error(IsSolved(r1cs, witness, opts...)) } // IsSolved attempts to solve the constraint system with provided witness // returns nil if it succeeds, error otherwise. -func IsSolved(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions []hint.Function) error { +func IsSolved(r1cs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) error { + + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return err + } + switch _r1cs := r1cs.(type) { case *backend_bls12377.R1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return _r1cs.IsSolved(w, hintFunctions) + return _r1cs.IsSolved(w, opt) default: panic("unrecognized R1CS curve type") } diff --git a/backend/groth16/fuzz.go b/backend/groth16/fuzz.go index 998f58bc69..d761b4b03c 100644 --- a/backend/groth16/fuzz.go +++ b/backend/groth16/fuzz.go @@ -4,41 +4,42 @@ package groth16 import ( - "strings" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" - witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" - backend_bn254 "github.com/consensys/gnark/internal/backend/bn254/cs" - witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" + // backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" + // witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness" + // backend_bn254 "github.com/consensys/gnark/internal/backend/bn254/cs" + // witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" ) +// TODO FIXME @gbotrel func Fuzz(data []byte) int { curves := []ecc.ID{ecc.BN254, ecc.BLS12_381} for _, curveID := range curves { - ccs, nbAssertions := frontend.CsFuzzed(data, curveID) - _, s, p := ccs.GetNbVariables() - wSize := s + p - 1 - ccs.SetLoggerOutput(nil) - switch _r1cs := ccs.(type) { - case *backend_bls12381.R1CS: - w := make(witness_bls12381.Witness, wSize) - // make w random - err := _r1cs.IsSolved(w, nil) - if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - panic("no assertions, yet solving resulted in an error.") - } - case *backend_bn254.R1CS: - w := make(witness_bn254.Witness, wSize) - // make w random - err := _r1cs.IsSolved(w, nil) - if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { - panic("no assertions, yet solving resulted in an error.") - } - default: - panic("unrecognized R1CS curve type") - } + frontend.CsFuzzed(data, curveID) + // _, s, p := ccs.GetNbVariables() + // wSize := s + p - 1 + // ccs.SetLoggerOutput(nil) + // switch _r1cs := ccs.(type) { + // case *backend_bls12381.R1CS: + // w := make(witness_bls12381.Witness, wSize) + // // make w random + // _ = _r1cs.IsSolved(w, nil) + // // TODO FIXME @gbotrel + // // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // // panic("no assertions, yet solving resulted in an error.") + // // } + // case *backend_bn254.R1CS: + // w := make(witness_bn254.Witness, wSize) + // // make w random + // _ = _r1cs.IsSolved(w, nil) + // // TODO FIXME @gbotrel + // // if nbAssertions == 0 && err != nil && !strings.Contains(err.Error(), "couldn't solve computational constraint") { + // // panic("no assertions, yet solving resulted in an error.") + // // } + // default: + // panic("unrecognized R1CS curve type") } + // } return 1 } diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index 8985e10eff..d159d523c3 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -24,7 +24,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" @@ -188,11 +188,12 @@ func ReadAndVerify(proof Proof, vk VerifyingKey, publicWitness io.Reader) error // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking -func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness frontend.Circuit, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } switch _r1cs := r1cs.(type) { @@ -201,31 +202,31 @@ func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness fronte if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, hintFunctions, _force) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, hintFunctions, _force) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, hintFunctions, _force) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, hintFunctions, _force) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return nil, err } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, hintFunctions, _force) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) default: panic("unrecognized R1CS curve type") } @@ -234,10 +235,12 @@ func Prove(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness fronte // ReadAndProve behaves like Prove, , except witness is read from a io.Reader // witness must be encoded following the binary serialization protocol described in // gnark/backend/witness package -func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, hintFunctions []hint.Function, force ...bool) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] +func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, opts ...func(opt *backend.ProverOption) error) (Proof, error) { + + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } _, nbSecret, nbPublic := r1cs.GetNbVariables() @@ -249,31 +252,31 @@ func ReadAndProve(r1cs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, hintFunctions, _force) + return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt) case *backend_bls12381.R1CS: w := witness_bls12381.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, hintFunctions, _force) + return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt) case *backend_bn254.R1CS: w := witness_bn254.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, hintFunctions, _force) + return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt) case *backend_bw6761.R1CS: w := witness_bw6761.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, hintFunctions, _force) + return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt) case *backend_bls24315.R1CS: w := witness_bls24315.Witness{} if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, hintFunctions, _force) + return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt) default: panic("unrecognized R1CS curve type") } diff --git a/backend/plonk/assert.go b/backend/plonk/assert.go index cd60a7d583..6a3419bfd6 100644 --- a/backend/plonk/assert.go +++ b/backend/plonk/assert.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -54,7 +54,7 @@ func NewAssert(t *testing.T) *Assert { return &Assert{require.New(t)} } -func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // checks if the system is solvable assert.SolvingSucceeded(ccs, witness) @@ -66,7 +66,7 @@ func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, wit assert.NoError(err, "Generating public data should not have failed") // generates the proof - proof, err := Prove(ccs, pk, witness, hintFunctions) + proof, err := Prove(ccs, pk, witness, opts...) assert.NoError(err, "Proving with good witness should not output an error") // verifies the proof @@ -75,7 +75,7 @@ func (assert *Assert) ProverSucceeded(ccs frontend.CompiledConstraintSystem, wit } -func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) { +func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) { // generates public data srs, err := newKZGSrs(ccs) @@ -84,7 +84,7 @@ func (assert *Assert) ProverFailed(ccs frontend.CompiledConstraintSystem, witnes assert.NoError(err, "Generating public data should not have failed") // generates the proof - _, err = Prove(ccs, pk, witness, hintFunctions) + _, err = Prove(ccs, pk, witness, opts...) assert.Error(err, "generating an incorrect proof should output an error") } @@ -100,38 +100,44 @@ func (assert *Assert) SolvingFailed(ccs frontend.CompiledConstraintSystem, witne // IsSolved attempts to solve the constraint system with provided witness // returns nil if it succeeds, error otherwise. -func IsSolved(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, hintFunctions ...hint.Function) error { +func IsSolved(ccs frontend.CompiledConstraintSystem, witness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) error { + + opt, err := backend.NewProverOption(opts...) + if err != nil { + return err + } + switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: w := witness_bn254.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls12381.SparseR1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls12377.SparseR1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bw6761.SparseR1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) case *cs_bls24315.SparseR1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(witness); err != nil { return err } - return tccs.IsSolved(w, hintFunctions) + return tccs.IsSolved(w, opt) default: panic("unknown constraint system type") } diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index 0437075fc8..2f456b9d24 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/kzg" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -158,11 +158,12 @@ func Setup(ccs frontend.CompiledConstraintSystem, kzgSRS kzg.SRS) (ProvingKey, V // will executes all the prover computations, even if the witness is invalid // will produce an invalid proof // internally, the solution vector to the SparseR1CS will be filled with random values which may impact benchmarking -func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness frontend.Circuit, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness frontend.Circuit, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } switch tccs := ccs.(type) { @@ -171,35 +172,35 @@ func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness fro if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, hintFunctions, _force) + return plonk_bn254.Prove(tccs, pk.(*plonk_bn254.ProvingKey), w, opt) case *cs_bls12381.SparseR1CS: w := witness_bls12381.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, hintFunctions, _force) + return plonk_bls12381.Prove(tccs, pk.(*plonk_bls12381.ProvingKey), w, opt) case *cs_bls12377.SparseR1CS: w := witness_bls12377.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, hintFunctions, _force) + return plonk_bls12377.Prove(tccs, pk.(*plonk_bls12377.ProvingKey), w, opt) case *cs_bw6761.SparseR1CS: w := witness_bw6761.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, hintFunctions, _force) + return plonk_bw6761.Prove(tccs, pk.(*plonk_bw6761.ProvingKey), w, opt) case *cs_bls24315.SparseR1CS: w := witness_bls24315.Witness{} if err := w.FromFullAssignment(fullWitness); err != nil { return nil, err } - return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, hintFunctions, _force) + return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), w, opt) default: panic("unrecognized SparseR1CS curve type") @@ -339,11 +340,12 @@ func NewVerifyingKey(curveID ecc.ID) VerifyingKey { } // ReadAndProve generates PLONK proof from a circuit, associated proving key, and the full witness -func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, hintFunctions []hint.Function, force ...bool) (Proof, error) { +func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness io.Reader, opts ...func(opt *backend.ProverOption) error) (Proof, error) { - _force := false - if len(force) > 0 { - _force = force[0] + // apply options + opt, err := backend.NewProverOption(opts...) + if err != nil { + return nil, err } _, nbSecret, nbPublic := ccs.GetNbVariables() @@ -356,7 +358,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bn254.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bn254.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -368,7 +370,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls12381.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls12381.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -380,7 +382,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls12377.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls12377.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -392,7 +394,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bw6761.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bw6761.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } @@ -404,7 +406,7 @@ func ReadAndProve(ccs frontend.CompiledConstraintSystem, pk ProvingKey, witness if _, err := w.LimitReadFrom(witness, expectedSize); err != nil { return nil, err } - proof, err := plonk_bls24315.Prove(tccs, _pk, w, hintFunctions, _force) + proof, err := plonk_bls24315.Prove(tccs, _pk, w, opt) if err != nil { return proof, err } diff --git a/debug_test.go b/debug_test.go new file mode 100644 index 0000000000..ee8211073a --- /dev/null +++ b/debug_test.go @@ -0,0 +1,206 @@ +package gnark + +import ( + "bytes" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// test println (non regression) +type printlnCircuit struct { + A, B frontend.Variable +} + +func (circuit *printlnCircuit) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + c := cs.Add(circuit.A, circuit.B) + cs.Println(c, "is the addition") + d := cs.Mul(circuit.A, c) + cs.Println(d, new(big.Int).SetInt64(42)) + bs := cs.ToBinary(circuit.B, 10) + cs.Println("bits", bs[3]) + cs.Println("circuit", circuit) + cs.AssertIsBoolean(cs.Constant(10)) // this will fail + m := cs.Mul(circuit.A, circuit.B) + cs.Println("m", m) // this should not be resolved + return nil +} + +func TestPrintln(t *testing.T) { + assert := require.New(t) + + var circuit, witness printlnCircuit + witness.A.Assign(2) + witness.B.Assign(11) + + var expected bytes.Buffer + expected.WriteString("debug_test.go:24 13 is the addition\n") + expected.WriteString("debug_test.go:26 26 42\n") + expected.WriteString("debug_test.go:28 bits 1\n") + expected.WriteString("debug_test.go:29 circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:32 m \n") + + { + trace, _ := getGroth16Trace(&circuit, &witness) + assert.Equal(trace, expected.String()) + } + + { + trace, _ := getPlonkTrace(&circuit, &witness) + assert.Equal(trace, expected.String()) + } +} + +// ------------------------------------------------------------------------------------------------- +// Div by 0 +type divBy0Trace struct { + A, B, C frontend.Variable +} + +func (circuit *divBy0Trace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.Div(circuit.A, d) + return nil +} + +func TestTraceDivBy0(t *testing.T) { + assert := require.New(t) + + var circuit, witness divBy0Trace + witness.A.Assign(2) + witness.B.Assign(-2) + witness.C.Assign(2) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") + assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") + assert.Contains(err.Error(), "debug_test.go:69") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [div] 2/(-2 + 2) == 0") + assert.Contains(err.Error(), "gnark.(*divBy0Trace).Define") + assert.Contains(err.Error(), "debug_test.go:69") + } +} + +// ------------------------------------------------------------------------------------------------- +// Not Equal +type notEqualTrace struct { + A, B, C frontend.Variable +} + +func (circuit *notEqualTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.AssertIsEqual(circuit.A, d) + return nil +} + +func TestTraceNotEqual(t *testing.T) { + assert := require.New(t) + + var circuit, witness notEqualTrace + witness.A.Assign(1) + witness.B.Assign(24) + witness.C.Assign(42) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") + assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") + assert.Contains(err.Error(), "debug_test.go:106") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsEqual] 1 == (24 + 42)") + assert.Contains(err.Error(), "gnark.(*notEqualTrace).Define") + assert.Contains(err.Error(), "debug_test.go:106") + } +} + +// ------------------------------------------------------------------------------------------------- +// Not boolean +type notBooleanTrace struct { + A, B, C frontend.Variable +} + +func (circuit *notBooleanTrace) Define(curveID ecc.ID, cs *frontend.ConstraintSystem) error { + d := cs.Add(circuit.B, circuit.C) + cs.AssertIsBoolean(d) + return nil +} + +func TestTraceNotBoolean(t *testing.T) { + assert := require.New(t) + + var circuit, witness notBooleanTrace + witness.A.Assign(1) + witness.B.Assign(24) + witness.C.Assign(42) + + { + _, err := getGroth16Trace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") + assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") + assert.Contains(err.Error(), "debug_test.go:143") + } + + { + _, err := getPlonkTrace(&circuit, &witness) + assert.Error(err) + assert.Contains(err.Error(), "constraint is not satisfied: [assertIsBoolean] (24 + 42) == (0|1)") + assert.Contains(err.Error(), "gnark.(*notBooleanTrace).Define") + assert.Contains(err.Error(), "debug_test.go:143") + } +} + +func getPlonkTrace(circuit, witness frontend.Circuit) (string, error) { + ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, circuit) + if err != nil { + return "", err + } + + srs, err := plonk.NewSRS(ccs) + if err != nil { + return "", err + } + pk, _, err := plonk.Setup(ccs, srs) + if err != nil { + return "", err + } + + var buf bytes.Buffer + _, err = plonk.Prove(ccs, pk, witness, backend.WithOutput(&buf)) + return buf.String(), err +} + +func getGroth16Trace(circuit, witness frontend.Circuit) (string, error) { + ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, circuit) + if err != nil { + return "", err + } + + pk, err := groth16.DummySetup(ccs) + if err != nil { + return "", err + } + + var buf bytes.Buffer + _, err = groth16.Prove(ccs, pk, witness, backend.WithOutput(&buf)) + return buf.String(), err +} diff --git a/frontend/cs.go b/frontend/cs.go index b7f0ad6a89..56a8a1cdfa 100644 --- a/frontend/cs.go +++ b/frontend/cs.go @@ -17,15 +17,9 @@ limitations under the License. package frontend import ( - "fmt" "io" "math/big" - "path/filepath" - "reflect" - "runtime" "sort" - "strconv" - "strings" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/hint" @@ -44,22 +38,21 @@ type ConstraintSystem struct { // they may only contain a linear expression public, secret, internal, virtual variables - // Constraints - constraints []compiled.R1C // list of R1C that yield an output (for example v3 == v1 * v2, return v3) - assertions []compiled.R1C // list of R1C that yield no output (for example ensuring v1 == v2) + // list of constraints in the form a * b == c + // a,b and c being linear expressions + constraints []compiled.R1C // Coefficients in the constraints coeffs []big.Int // list of unique coefficients. coeffsIDs map[string]int // map to fast check existence of a coefficient (key = coeff.Text(16)) // Hints - hints []compiled.Hint // solver hints + mHints map[int]compiled.Hint // solver hints - // debug info - logs []logEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println - debugInfoComputation []logEntry // list of logs storing information about computations (e.g. division by 0).If an computation fails, it prints it in a friendly format - debugInfoAssertion []logEntry // list of logs storing information about assertions. If an assertion fails, it prints it in a friendly format + logs []compiled.LogEntry // list of logs to be printed when solving a circuit. The logs are called with the method Println + debugInfo []compiled.LogEntry // list of logs storing information about R1C + mDebug map[int]int // maps constraint ID to debugInfo id } type variables struct { @@ -86,9 +79,6 @@ type CompiledConstraintSystem interface { GetNbConstraints() int GetNbCoefficients() int - // SetLoggerOutput replace existing logger output with provided one - SetLoggerOutput(w io.Writer) - CurveID() ecc.ID FrSize() int @@ -107,7 +97,8 @@ func newConstraintSystem(initialCapacity ...int) ConstraintSystem { coeffs: make([]big.Int, 4), coeffsIDs: make(map[string]int), constraints: make([]compiled.R1C, 0, capacity), - assertions: make([]compiled.R1C, 0), + mDebug: make(map[int]int), + mHints: make(map[int]compiled.Hint), } cs.coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -130,8 +121,6 @@ func newConstraintSystem(initialCapacity ...int) ConstraintSystem { // by default the circuit is given on public wire equal to 1 cs.public.variables[0] = cs.newPublicVariable() - cs.hints = make([]compiled.Hint, 0) - return cs } @@ -158,89 +147,11 @@ func (cs *ConstraintSystem) NewHint(hintID hint.ID, inputs ...interface{}) Varia } // add the hint to the constraint system - cs.hints = append(cs.hints, compiled.Hint{WireID: r.id, ID: hintID, Inputs: hintInputs}) + cs.mHints[r.id] = compiled.Hint{ID: hintID, Inputs: hintInputs} return r } -// Println enables circuit debugging and behaves almost like fmt.Println() -// -// the print will be done once the R1CS.Solve() method is executed -// -// if one of the input is a Variable, its value will be resolved avec R1CS.Solve() method is called -func (cs *ConstraintSystem) Println(a ...interface{}) { - var sbb strings.Builder - - // prefix log line with file.go:line - if _, file, line, ok := runtime.Caller(1); ok { - sbb.WriteString(filepath.Base(file)) - sbb.WriteByte(':') - sbb.WriteString(strconv.Itoa(line)) - sbb.WriteByte(' ') - } - - // for each argument, if it is a circuit structure and contains variable - // we add the variables in the logEntry.toResolve part, and add %s to the format string in the log entry - // if it doesn't contain variable, call fmt.Sprint(arg) instead - entry := logEntry{} - - // this is call recursively on the arguments using reflection on each argument - foundVariable := false - - var handler logValueHandler = func(name string, tInput reflect.Value) { - - v := tInput.Interface().(Variable) - - // if the variable is only in linExp form, we allocate it - _v := cs.allocate(v) - - entry.toResolve = append(entry.toResolve, compiled.Pack(_v.id, 0, _v.visibility)) - - if name == "" { - sbb.WriteString("%s") - } else { - sbb.WriteString(fmt.Sprintf("%s: %%s ", name)) - } - - foundVariable = true - } - - for i, arg := range a { - if i > 0 { - sbb.WriteByte(' ') - } - foundVariable = false - parseLogValue(arg, "", handler) - if !foundVariable { - sbb.WriteString(fmt.Sprint(arg)) - } - } - sbb.WriteByte('\n') - - // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method - entry.format = sbb.String() - - cs.logs = append(cs.logs, entry) -} - -type logEntry struct { - format string - toResolve []compiled.Term -} - -var ( - bOne = new(big.Int).SetInt64(1) -) - -// debug info in case a variable is not set -// func debugInfoUnsetVariable(term compiled.Term) logEntry { -// entry := logEntry{} -// stack := getCallStack() -// entry.format = stack[len(stack)-1] -// entry.toResolve = append(entry.toResolve, term) -// return entry -// } - func (cs *ConstraintSystem) one() Variable { return cs.public.variables[0] } @@ -268,11 +179,8 @@ func newR1C(l, r, o Variable) compiled.R1C { // NbConstraints enables circuit profiling and helps debugging // It returns the number of constraints created at the current stage of the circuit construction. -// -// The number returns included both the assertions and the non-assertion constraints -// (eg: the constraints which creates a new variable) func (cs *ConstraintSystem) NbConstraints() int { - return len(cs.constraints) + len(cs.assertions) + return len(cs.constraints) } // LinearExpression packs a list of compiled.Term in a compiled.LinearExpression and returns it. @@ -310,11 +218,6 @@ func (cs *ConstraintSystem) reduce(l compiled.LinearExpression) compiled.LinearE return l } -func (cs *ConstraintSystem) addAssertion(constraint compiled.R1C, debugInfo logEntry) { - cs.assertions = append(cs.assertions, constraint) - cs.debugInfoAssertion = append(cs.debugInfoAssertion, debugInfo) -} - // coeffID tries to fetch the entry where b is if it exits, otherwise appends b to // the list of coeffs and returns the corresponding entry func (cs *ConstraintSystem) coeffID(b *big.Int) int { @@ -349,17 +252,11 @@ func (cs *ConstraintSystem) coeffID(b *big.Int) int { return resID } -// if v is unset and linExp is non empty, the variable is allocated -// resulting in one more constraint in the system. If v is set OR v is -// unset and linexp is emppty, it does nothing. -func (cs *ConstraintSystem) allocate(v Variable) Variable { - if v.visibility == compiled.Unset && len(v.linExp) > 0 { - iv := cs.newInternalVariable() - one := cs.one() - cs.constraints = append(cs.constraints, newR1C(v, one, iv)) - return iv +func (cs *ConstraintSystem) addConstraint(r1c compiled.R1C, debugID ...int) { + cs.constraints = append(cs.constraints, r1c) + if len(debugID) > 0 { + cs.mDebug[len(cs.constraints)-1] = debugID[0] } - return v } // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets @@ -385,104 +282,8 @@ func (cs *ConstraintSystem) newVirtualVariable() Variable { return cs.virtual.new(cs, compiled.Virtual) } -type logValueHandler func(name string, tValue reflect.Value) - -func appendName(baseName, name string) string { - if baseName == "" { - return name - } - return baseName + "_" + name -} - -func parseLogValue(input interface{}, name string, handler logValueHandler) { - tVariable := reflect.TypeOf(Variable{}) - - tValue := reflect.ValueOf(input) - if tValue.Kind() == reflect.Ptr { - tValue = tValue.Elem() - } - switch tValue.Kind() { - case reflect.Struct: - switch tValue.Type() { - case tVariable: - handler(name, tValue) - return - default: - for i := 0; i < tValue.NumField(); i++ { - if tValue.Field(i).CanInterface() { - value := tValue.Field(i).Interface() - _name := appendName(name, tValue.Type().Field(i).Name) - parseLogValue(value, _name, handler) - } - } - } - case reflect.Slice, reflect.Array: - if tValue.Len() == 0 { - fmt.Println("warning, got unitizalized slice (or empty array). Ignoring;") - return - } - for j := 0; j < tValue.Len(); j++ { - value := tValue.Index(j).Interface() - entry := "[" + strconv.Itoa(j) + "]" - _name := appendName(name, entry) - parseLogValue(value, _name, handler) - } - } -} - -// derived from: https://golang.org/pkg/runtime/#example_Frames -// we stop when func name == Define as it is where the gnark circuit code should start -func getCallStack() []string { - // Ask runtime.Callers for up to 10 pcs - pc := make([]uintptr, 10) - n := runtime.Callers(3, pc) - if n == 0 { - // No pcs available. Stop now. - // This can happen if the first argument to runtime.Callers is large. - return nil - } - pc = pc[:n] // pass only valid pcs to runtime.CallersFrames - frames := runtime.CallersFrames(pc) - // Loop to get frames. - // A fixed number of pcs can expand to an indefinite number of Frames. - var toReturn []string - for { - frame, more := frames.Next() - fe := strings.Split(frame.Function, "/") - function := fe[len(fe)-1] - toReturn = append(toReturn, fmt.Sprintf("%s\n\t%s:%d", function, frame.File, frame.Line)) - if !more { - break - } - if strings.HasSuffix(function, "Define") { - break - } - } - return toReturn -} - func (cs *ConstraintSystem) buildVarFromWire(pv Wire) Variable { - return Variable{pv, cs.LinearExpression(cs.makeTerm(pv, bOne))} -} - -// creates a string formatted to display correctly a variable, from its linear expression representation -// (i.e. the linear expression leading to it) -func (cs *ConstraintSystem) buildLogEntryFromVariable(v Variable) logEntry { - - var res logEntry - var sbb strings.Builder - sbb.Grow(len(v.linExp) * len(" + (xx + xxxxxxxxxxxx")) - - for i := 0; i < len(v.linExp); i++ { - if i > 0 { - sbb.WriteString(" + ") - } - c := cs.coeffs[v.linExp[i].CoeffID()] - sbb.WriteString(fmt.Sprintf("(%%s * %s)", c.String())) - } - res.format = sbb.String() - res.toResolve = v.linExp.Clone() - return res + return Variable{pv, cs.LinearExpression(compiled.Pack(pv.id, compiled.CoeffIdOne, pv.visibility))} } // markBoolean marks the variable as boolean and return true diff --git a/frontend/cs_api.go b/frontend/cs_api.go index 6deab0f4c9..7b370b875c 100644 --- a/frontend/cs_api.go +++ b/frontend/cs_api.go @@ -18,7 +18,6 @@ package frontend import ( "math/big" - "strings" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/internal/backend/compiled" @@ -185,79 +184,29 @@ func (cs *ConstraintSystem) Mul(i1, i2 interface{}, in ...interface{}) Variable // Inverse returns res = inverse(v) func (cs *ConstraintSystem) Inverse(v Variable) Variable { - v.assertIsSet() // allocate resulting variable res := cs.newInternalVariable() - cs.constraints = append(cs.constraints, newR1C(v, res, cs.one())) - - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("couldn't solve computational constraint (inversion by zero ?)") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() + debug := cs.addDebugInfo("inverse", v, "*", res, " == 1") - // add it to the logs record - cs.debugInfoComputation = append(cs.debugInfoComputation, debugInfo) + cs.addConstraint(newR1C(v, res, cs.one()), debug) return res } // Div returns res = i1 / i2 func (cs *ConstraintSystem) Div(i1, i2 interface{}) Variable { - // allocate resulting variable res := cs.newInternalVariable() - // O - switch t1 := i1.(type) { - case Variable: - t1.assertIsSet() - switch t2 := i2.(type) { - case Variable: - t2.assertIsSet() - cs.constraints = append(cs.constraints, newR1C(t2, res, t1)) - default: - tmp := cs.Constant(t2) - cs.constraints = append(cs.constraints, newR1C(res, tmp, t1)) - } - default: - switch t2 := i2.(type) { - case Variable: - t2.assertIsSet() - tmp := cs.Constant(t1) - cs.constraints = append(cs.constraints, newR1C(t2, res, tmp)) - default: - tmp1 := cs.Constant(t1) - tmp2 := cs.Constant(t2) - cs.constraints = append(cs.constraints, newR1C(res, tmp2, tmp1)) - } - } + v1 := cs.Constant(i1) + v2 := cs.Constant(i2) - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("couldn't solve computational constraint (inversion by zero ?)") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() + debug := cs.addDebugInfo("div", v1, "/", v2, " == ", res) - // add it to the logs record - cs.debugInfoComputation = append(cs.debugInfoComputation, debugInfo) + cs.addConstraint(newR1C(v2, res, v1), debug) return res } @@ -316,6 +265,7 @@ func (cs *ConstraintSystem) And(a, b Variable) Variable { // IsZero returns 1 if a is zero, 0 otherwise func (cs *ConstraintSystem) IsZero(a Variable) Variable { a.assertIsSet() + debug := cs.addDebugInfo("isZero", a) //m * (1 - m) = 0 // constrain m to be 0 or 1 // a * m = 0 // constrain m to be 0 if a != 0 @@ -323,7 +273,7 @@ func (cs *ConstraintSystem) IsZero(a Variable) Variable { // m is computed by the solver such that m = 1 - a^(modulus - 1) m := cs.NewHint(hint.IsZero, a) - cs.constraints = append(cs.constraints, newR1C(a, m, cs.Constant(0))) + cs.addConstraint(newR1C(a, m, cs.Constant(0)), debug) cs.AssertIsBoolean(m) ma := cs.Add(m, a) @@ -349,7 +299,7 @@ func (cs *ConstraintSystem) ToBinary(a Variable, nbBits int) []Variable { // here what we do is we add a single constraint where // Σ (2**i * b[i]) == a var c big.Int - c.Set(bOne) + c.SetUint64(1) var Σbi Variable Σbi.linExp = make(compiled.LinearExpression, nbBits) @@ -378,7 +328,7 @@ func (cs *ConstraintSystem) FromBinary(b ...Variable) Variable { res = cs.Constant(0) // no constraint is recorded var c big.Int - c.Set(bOne) + c.SetUint64(1) L := make(compiled.LinearExpression, len(b)) for i := 0; i < len(L); i++ { diff --git a/frontend/cs_api_test.go b/frontend/cs_api_test.go index 69b5497b6d..ddd23060f7 100644 --- a/frontend/cs_api_test.go +++ b/frontend/cs_api_test.go @@ -19,7 +19,6 @@ type csState struct { nbSecretVariables int nbInternalVariables int nbConstraints int - nbAssertions int } // deltaState holds the difference between the next state (after calling a function from the API) and the previous one @@ -38,7 +37,7 @@ type nextstatefunc func(state commands.State) commands.State // the names of the public/secret inputs are variableName.String() func incVariableName() { - variableName.Add(&variableName, bOne) + variableName.Add(&variableName, new(big.Int).SetUint64(1)) } // ------------------------------------------------------------------------------ @@ -103,8 +102,7 @@ func postConditionAPI(state commands.State, result commands.Result) *gopter.Prop if len(csRes.cs.public.variables) != st.nbPublicVariables || len(csRes.cs.secret.variables) != st.nbSecretVariables || len(csRes.cs.internal.variables) != st.nbInternalVariables || - len(csRes.cs.constraints) != st.nbConstraints || - len(csRes.cs.assertions) != st.nbAssertions { + len(csRes.cs.constraints) != st.nbConstraints { return &gopter.PropResult{Status: gopter.PropFalse} } return &gopter.PropResult{Status: gopter.PropTrue} @@ -150,7 +148,7 @@ func rfAddSub() runfunc { return res } -var nsAddSub = deltaState{1, 2, 0, 0, 0} // ex: after calling add, we should have 1 public variable, 3 secret variables, 0 internal variable, 0 constraint more in the cs +var nsAddSub = deltaState{1, 2, 0, 0} // ex: after calling add, we should have 1 public variable, 3 secret variables, 0 internal variable, 0 constraint more in the cs // mul variables func rfMul() runfunc { @@ -182,7 +180,7 @@ func rfMul() runfunc { return res } -var nsMul = csState{1, 1, 1, 1, 0} +var nsMul = csState{1, 1, 1, 1} // inverse a variable func rfInverse() runfunc { @@ -210,7 +208,7 @@ func rfInverse() runfunc { return res } -var nsInverse = deltaState{1, 0, 1, 1, 0} +var nsInverse = deltaState{1, 0, 1, 1} // div 2 variables func rfDiv() runfunc { @@ -251,7 +249,7 @@ func rfDiv() runfunc { return res } -var nsDiv = deltaState{1, 1, 4, 4, 0} +var nsDiv = deltaState{1, 1, 4, 4} // xor between two variables func rfXor() runfunc { @@ -283,7 +281,7 @@ func rfXor() runfunc { return res } -var nsXor = deltaState{1, 1, 1, 1, 2} +var nsXor = deltaState{1, 1, 1, 3} // binary decomposition of a variable func rfToBinary() runfunc { @@ -309,7 +307,7 @@ func rfToBinary() runfunc { return res } -var nsToBinary = deltaState{1, 0, 256, 1, 256} +var nsToBinary = deltaState{1, 0, 256, 257} // select constraint betwwen variableq func rfSelect() runfunc { @@ -353,7 +351,7 @@ func rfSelect() runfunc { return res } -var nsSelect = deltaState{1, 2, 3, 3, 1} +var nsSelect = deltaState{1, 2, 3, 4} // copy of variable func rfConstant() runfunc { @@ -381,7 +379,7 @@ func rfConstant() runfunc { return res } -var nsConstant = deltaState{1, 0, 0, 0, 0} +var nsConstant = deltaState{1, 0, 0, 0} // equality between 2 variables func rfIsEqual() runfunc { @@ -416,7 +414,7 @@ func rfIsEqual() runfunc { return res } -var nsIsEqual = deltaState{1, 1, 0, 0, 2} +var nsIsEqual = deltaState{1, 1, 0, 2} // packing from binary variables func rfFromBinary() runfunc { @@ -444,7 +442,7 @@ func rfFromBinary() runfunc { return res } -var nsFromBinary = deltaState{256, 0, 0, 0, 256} +var nsFromBinary = deltaState{256, 0, 0, 256} // boolean constrain a variable func rfIsBoolean() runfunc { @@ -481,11 +479,11 @@ func rfIsBoolean() runfunc { return res } -var nsIsBoolean = deltaState{1, 1, 0, 0, 2} +var nsIsBoolean = deltaState{1, 1, 0, 2} -var nsMustBeLessOrEqVar = deltaState{1, 1, 1281, 771, 768} +var nsMustBeLessOrEqVar = deltaState{1, 1, 1281, 1539} -var nsMustBeLessOrEqConst = csState{1, 0, 257, 2, 511} // nb internal variables: 256+HW(bound), nb constraints: 1+HW(bound), nb assertions: 256+HW(^bound) +var nsMustBeLessOrEqConst = csState{1, 0, 257, 513} // nb internal variables: 256+HW(bound), nb constraints: 1+HW(bound), nb assertions: 256+HW(^bound) // ------------------------------------------------------------------------------ // build the next state function using the delta state @@ -496,7 +494,6 @@ func nextStateFunc(ds deltaState) nextstatefunc { state.(*csState).nbSecretVariables += ds.nbSecretVariables state.(*csState).nbInternalVariables += ds.nbInternalVariables state.(*csState).nbConstraints += ds.nbConstraints - state.(*csState).nbAssertions += ds.nbAssertions return state } return res @@ -679,11 +676,6 @@ func (c *isLessOrEq) Define(curveID ecc.ID, cs *ConstraintSystem) error { } func TestUnsetVariables(t *testing.T) { - // TODO unset variables with markBoolean will panic. - // doing - // var a Variable - // cs.AssertIsBoolean(a) - // will panic. mapFuncs := map[string]Circuit{ "add": &addCircuit{}, "sub": &subCircuit{}, diff --git a/frontend/cs_assertions.go b/frontend/cs_assertions.go index 9af35cd853..466af0d664 100644 --- a/frontend/cs_assertions.go +++ b/frontend/cs_assertions.go @@ -2,48 +2,25 @@ package frontend import ( "math/big" - "strings" "github.com/consensys/gnark/internal/backend/compiled" ) // AssertIsEqual adds an assertion in the constraint system (i1 == i2) func (cs *ConstraintSystem) AssertIsEqual(i1, i2 interface{}) { + // encoded i1 * 1 == i2 + // TODO do cs.Sub(i1,i2) == 0 ? - // encoded as L * R == O - // set L = i1 - // set R = 1 - // set O = i2 - - // we don't do just "cs.Sub(i1,i2)" to allow proper logging - debugInfo := logEntry{} - - l := cs.Constant(i1) // no constraint is recorded - r := cs.Constant(1) // no constraint is recorded - o := cs.Constant(i2) // no constraint is recorded - - // build log - var sbb strings.Builder - sbb.WriteString("[") - lhs := cs.buildLogEntryFromVariable(l) - sbb.WriteString(lhs.format) - debugInfo.toResolve = lhs.toResolve - sbb.WriteString(" != ") - rhs := cs.buildLogEntryFromVariable(o) - sbb.WriteString(rhs.format) - debugInfo.toResolve = append(debugInfo.toResolve, rhs.toResolve...) - sbb.WriteString("]") - - // get call stack - sbb.WriteString("error AssertIsEqual") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) + l := cs.Constant(i1) + o := cs.Constant(i2) + + if len(l.linExp) > len(o.linExp) { + l, o = o, l // maximize number of zeroes in r1cs.A } - debugInfo.format = sbb.String() - cs.addAssertion(newR1C(l, r, o), debugInfo) + debug := cs.addDebugInfo("assertIsEqual", l, " == ", o) + + cs.addConstraint(newR1C(l, cs.one(), o), debug) } // AssertIsDifferent constrain i1 and i2 to be different @@ -53,9 +30,7 @@ func (cs *ConstraintSystem) AssertIsDifferent(i1, i2 interface{}) { // AssertIsBoolean adds an assertion in the constraint system (v == 0 || v == 1) func (cs *ConstraintSystem) AssertIsBoolean(v Variable) { - v.assertIsSet() - if v.visibility == compiled.Unset { // we need to create a new wire here. vv := cs.newVirtualVariable() @@ -66,26 +41,12 @@ func (cs *ConstraintSystem) AssertIsBoolean(v Variable) { if !cs.markBoolean(v) { return // variable is already constrained } + debug := cs.addDebugInfo("assertIsBoolean", v, " == (0|1)") // ensure v * (1 - v) == 0 - - _v := cs.Sub(1, v) // no variable is recorded in the cs - o := cs.Constant(0) // no variable is recorded in the cs - - // prepare debug info to be displayed in case the constraint is not solved - debugInfo := logEntry{ - toResolve: nil, - } - var sbb strings.Builder - sbb.WriteString("error AssertIsBoolean") - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() - - cs.addAssertion(newR1C(v, _v, o), debugInfo) + _v := cs.Sub(1, v) + o := cs.Constant(0) + cs.addConstraint(newR1C(v, _v, o), debug) } // AssertIsLessOrEqual adds assertion in constraint system (v <= bound) @@ -108,31 +69,14 @@ func (cs *ConstraintSystem) AssertIsLessOrEqual(v Variable, bound interface{}) { } -func (cs *ConstraintSystem) mustBeLessOrEqVar(w, bound Variable) { - - // prepare debug info to be displayed in case the constraint is not solved - dbgInfoW := cs.buildLogEntryFromVariable(w) - dbgInfoBound := cs.buildLogEntryFromVariable(bound) - var sbb strings.Builder - var debugInfo logEntry - sbb.WriteString(dbgInfoW.format) - sbb.WriteString(" <= ") - sbb.WriteString(dbgInfoBound.format) - debugInfo.toResolve = make([]compiled.Term, len(dbgInfoW.toResolve)+len(dbgInfoBound.toResolve)) - copy(debugInfo.toResolve[:], dbgInfoW.toResolve) - copy(debugInfo.toResolve[len(dbgInfoW.toResolve):], dbgInfoBound.toResolve) - - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() +func (cs *ConstraintSystem) mustBeLessOrEqVar(v, bound Variable) { + debug := cs.addDebugInfo("mustBeLessOrEq", v, " <= ", bound) + // TODO nbBits shouldn't be here. const nbBits = 256 - binw := cs.ToBinary(w, nbBits) - binbound := cs.ToBinary(bound, nbBits) + wBits := cs.ToBinary(v, nbBits) + boundBits := cs.ToBinary(bound, nbBits) p := make([]Variable, nbBits+1) p[nbBits] = cs.Constant(1) @@ -141,41 +85,25 @@ func (cs *ConstraintSystem) mustBeLessOrEqVar(w, bound Variable) { for i := nbBits - 1; i >= 0; i-- { - p1 := cs.Mul(p[i+1], binw[i]) - p[i] = cs.Select(binbound[i], p1, p[i+1]) - t := cs.Select(binbound[i], zero, p[i+1]) + p1 := cs.Mul(p[i+1], wBits[i]) + p[i] = cs.Select(boundBits[i], p1, p[i+1]) + t := cs.Select(boundBits[i], zero, p[i+1]) l := cs.one() - l = cs.Sub(l, t) // no constraint is recorded - l = cs.Sub(l, binw[i]) // no constraint is recorded + l = cs.Sub(l, t) // no constraint is recorded + l = cs.Sub(l, wBits[i]) // no constraint is recorded - r := binw[i] + r := wBits[i] o := cs.Constant(0) // no constraint is recorded - cs.addAssertion(newR1C(l, r, o), debugInfo) + cs.addConstraint(newR1C(l, r, o), debug) } } func (cs *ConstraintSystem) mustBeLessOrEqCst(v Variable, bound big.Int) { - - // prepare debug info to be displayed in case the constraint is not solved - dbgInfoW := cs.buildLogEntryFromVariable(v) - var sbb strings.Builder - var debugInfo logEntry - sbb.WriteString(dbgInfoW.format) - sbb.WriteString(" <= ") - sbb.WriteString(bound.String()) - - debugInfo.toResolve = dbgInfoW.toResolve - - stack := getCallStack() - for i := 0; i < len(stack); i++ { - sbb.WriteByte('\n') - sbb.WriteString(stack[i]) - } - debugInfo.format = sbb.String() + debug := cs.addDebugInfo("mustBeLessOrEq", v, " <= ", cs.Constant(bound)) // TODO store those constant elsewhere (for the moment they don't depend on the base curve, but that might change) const nbBits = 256 @@ -206,7 +134,8 @@ func (cs *ConstraintSystem) mustBeLessOrEqCst(v Variable, bound big.Int) { r := vBits[(i+1)*wordSize-1-j] o := cs.Constant(0) - cs.addAssertion(newR1C(l, r, o), debugInfo) + + cs.addConstraint(newR1C(l, r, o), debug) } else { p[(i+1)*wordSize-1-j] = cs.Mul(p[(i+1)*wordSize-j], vBits[(i+1)*wordSize-1-j]) diff --git a/frontend/cs_debug.go b/frontend/cs_debug.go new file mode 100644 index 0000000000..b9fb30e9b9 --- /dev/null +++ b/frontend/cs_debug.go @@ -0,0 +1,135 @@ +package frontend + +import ( + "fmt" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + + "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/internal/parser" +) + +// Println enables circuit debugging and behaves almost like fmt.Println() +// +// the print will be done once the R1CS.Solve() method is executed +// +// if one of the input is a Variable, its value will be resolved avec R1CS.Solve() method is called +func (cs *ConstraintSystem) Println(a ...interface{}) { + var sbb strings.Builder + + // prefix log line with file.go:line + if _, file, line, ok := runtime.Caller(1); ok { + sbb.WriteString(filepath.Base(file)) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(line)) + sbb.WriteByte(' ') + } + + var log compiled.LogEntry + + for i, arg := range a { + if i > 0 { + sbb.WriteByte(' ') + } + if v, ok := arg.(Variable); ok { + v.assertIsSet() + + sbb.WriteString("%s") + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + log.ToResolve = append(log.ToResolve, v.linExp...) + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + } else { + printArg(&log, &sbb, arg) + } + } + sbb.WriteByte('\n') + + // set format string to be used with fmt.Sprintf, once the variables are solved in the R1CS.Solve() method + log.Format = sbb.String() + + cs.logs = append(cs.logs, log) +} + +func printArg(log *compiled.LogEntry, sbb *strings.Builder, a interface{}) { + + count := 0 + counter := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { + count++ + return nil + } + // ignoring error, counter() always return nil + _ = parser.Visit(a, "", compiled.Unset, counter, reflect.TypeOf(Variable{})) + + // no variables in nested struct, we use fmt std print function + if count == 0 { + sbb.WriteString(fmt.Sprint(a)) + return + } + + sbb.WriteByte('{') + printer := func(visibility compiled.Visibility, name string, tValue reflect.Value) error { + count-- + sbb.WriteString(name) + sbb.WriteString(": ") + sbb.WriteString("%s") + if count != 0 { + sbb.WriteString(", ") + } + + v := tValue.Interface().(Variable) + // we set limits to the linear expression, so that the log printer + // can evaluate it before printing it + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + log.ToResolve = append(log.ToResolve, v.linExp...) + log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) + return nil + } + // ignoring error, printer() doesn't return errors + _ = parser.Visit(a, "", compiled.Unset, printer, reflect.TypeOf(Variable{})) + sbb.WriteByte('}') +} + +func (cs *ConstraintSystem) addDebugInfo(errName string, i ...interface{}) int { + var debug compiled.LogEntry + + // TODO @gbotrel reserve capacity for the string builder + const minLogSize = 500 + var sbb strings.Builder + sbb.Grow(minLogSize) + sbb.WriteString("[") + sbb.WriteString(errName) + sbb.WriteString("] ") + + for _, _i := range i { + switch v := _i.(type) { + case Variable: + if len(v.linExp) > 1 { + sbb.WriteString("(") + } + debug.WriteLinearExpression(v.linExp, &sbb) + if len(v.linExp) > 1 { + sbb.WriteString(")") + } + + case string: + sbb.WriteString(v) + case int: + sbb.WriteString(strconv.Itoa(v)) + case compiled.Term: + debug.WriteTerm(v, &sbb) + default: + panic("unsupported log type") + } + } + sbb.WriteByte('\n') + debug.WriteStack(&sbb) + debug.Format = sbb.String() + + cs.debugInfo = append(cs.debugInfo, debug) + return len(cs.debugInfo) - 1 +} diff --git a/frontend/cs_to_r1cs.go b/frontend/cs_to_r1cs.go index 37978458f2..42913c0fbc 100644 --- a/frontend/cs_to_r1cs.go +++ b/frontend/cs_to_r1cs.go @@ -18,25 +18,35 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er // setting up the result res := compiled.R1CS{ - NbInternalVariables: len(cs.internal.variables), - NbPublicVariables: len(cs.public.variables), - NbSecretVariables: len(cs.secret.variables), - NbConstraints: len(cs.constraints) + len(cs.assertions), - Constraints: make([]compiled.R1C, len(cs.constraints)+len(cs.assertions)), - Logs: make([]compiled.LogEntry, len(cs.logs)), - DebugInfoComputation: make([]compiled.LogEntry, len(cs.debugInfoComputation)+len(cs.debugInfoAssertion)), - Hints: make([]compiled.Hint, len(cs.hints)), + CS: compiled.CS{ + NbInternalVariables: len(cs.internal.variables), + NbPublicVariables: len(cs.public.variables), + NbSecretVariables: len(cs.secret.variables), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), + Logs: make([]compiled.LogEntry, len(cs.logs)), + MHints: make(map[int]compiled.Hint, len(cs.mHints)), + MDebug: make(map[int]int), + }, + Constraints: make([]compiled.R1C, len(cs.constraints)), } + // for logs, debugInfo and hints the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires + // computational constraints (= gates) - copy(res.Constraints, cs.constraints) - copy(res.Constraints[len(cs.constraints):], cs.assertions) + for i, r1c := range cs.constraints { + res.Constraints[i] = compiled.R1C{ + L: r1c.L.Clone(), + R: r1c.R.Clone(), + O: r1c.O.Clone(), + } + } - // note: verbose, but we offset the IDs of the wires where they appear, that is, - // in the logs, debug info, constraints and hints - // since we don't use pointers but Terms (uint64), we need to potentially offset - // the same wireID multiple times. - copy(res.Hints, cs.hints) + // for a R1CS, the correspondance between constraint and debug info won't change, we just copy + for k, v := range cs.mDebug { + res.MDebug[k] = v + } // offset variable ID depeneding on visibility shiftVID := func(oldID int, visibility compiled.Visibility) int { @@ -66,49 +76,40 @@ func (cs *ConstraintSystem) toR1CS(curveID ecc.ID) (CompiledConstraintSystem, er } // we need to offset the ids in the hints - for i := 0; i < len(res.Hints); i++ { - res.Hints[i].WireID = shiftVID(res.Hints[i].WireID, compiled.Internal) - for j := 0; j < len(res.Hints[i].Inputs); j++ { - offsetIDs(res.Hints[i].Inputs[j]) + for vID, hint := range cs.mHints { + k := shiftVID(vID, compiled.Internal) + inputs := make([]compiled.LinearExpression, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + offsetIDs(inputs[j]) } - + res.MHints[k] = compiled.Hint{ID: hint.ID, Inputs: inputs} } - // we need to offset the ids in logs + // we need to offset the ids in logs & debugInfo for i := 0; i < len(cs.logs); i++ { - entry := compiled.LogEntry{ - Format: cs.logs[i].format, - } - for j := 0; j < len(cs.logs[i].toResolve); j++ { - _, vID, visibility := cs.logs[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) + res.Logs[i] = compiled.LogEntry{ + Format: cs.logs[i].Format, + ToResolve: make([]compiled.Term, len(cs.logs[i].ToResolve)), } + copy(res.Logs[i].ToResolve, cs.logs[i].ToResolve) - res.Logs[i] = entry - } - - // offset ids in the debugInfoComputation - for i := 0; i < len(cs.debugInfoComputation); i++ { - entry := compiled.LogEntry{ - Format: cs.debugInfoComputation[i].format, - } - for j := 0; j < len(cs.debugInfoComputation[i].toResolve); j++ { - _, vID, visibility := cs.debugInfoComputation[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() + res.Logs[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) } - - res.DebugInfoComputation[i] = entry } - for i := 0; i < len(cs.debugInfoAssertion); i++ { - entry := compiled.LogEntry{ - Format: cs.debugInfoAssertion[i].format, - } - for j := 0; j < len(cs.debugInfoAssertion[i].toResolve); j++ { - _, vID, visibility := cs.debugInfoAssertion[i].toResolve[j].Unpack() - entry.ToResolve = append(entry.ToResolve, shiftVID(vID, visibility)) + for i := 0; i < len(cs.debugInfo); i++ { + res.DebugInfo[i] = compiled.LogEntry{ + Format: cs.debugInfo[i].Format, + ToResolve: make([]compiled.Term, len(cs.debugInfo[i].ToResolve)), } + copy(res.DebugInfo[i].ToResolve, cs.debugInfo[i].ToResolve) - res.DebugInfoComputation[i+len(cs.debugInfoComputation)] = entry + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() + res.DebugInfo[i].ToResolve[j].SetVariableID(shiftVID(vID, visibility)) + } } switch curveID { diff --git a/frontend/cs_to_r1cs_sparse.go b/frontend/cs_to_r1cs_sparse.go index 1658b2a709..44c9a809b5 100644 --- a/frontend/cs_to_r1cs_sparse.go +++ b/frontend/cs_to_r1cs_sparse.go @@ -48,46 +48,58 @@ type sparseR1CS struct { // and guarantee that the solver will encounter at most one unsolved wire // per SparseR1C solvedVariables []bool + + currentR1CDebugID int // mark the current R1C debugID } +var bOne = new(big.Int).SetInt64(1) + func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSystem, error) { res := sparseR1CS{ ConstraintSystem: cs, ccs: compiled.SparseR1CS{ - NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded as it is not used in PLONK - NbSecretVariables: len(cs.secret.variables), - NbInternalVariables: len(cs.internal.variables), - Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)+len(cs.assertions)), - Logs: make([]compiled.LogEntry, len(cs.logs)), - Hints: make([]compiled.Hint, len(cs.hints)), + CS: compiled.CS{ + NbInternalVariables: len(cs.internal.variables), + NbPublicVariables: len(cs.public.variables) - 1, // the ONE_WIRE is discarded in PlonK + NbSecretVariables: len(cs.secret.variables), + DebugInfo: make([]compiled.LogEntry, len(cs.debugInfo)), + Logs: make([]compiled.LogEntry, len(cs.logs)), + MDebug: make(map[int]int), + MHints: make(map[int]compiled.Hint), + }, + Constraints: make([]compiled.SparseR1C, 0, len(cs.constraints)), }, solvedVariables: make([]bool, len(cs.internal.variables), len(cs.internal.variables)*2), scsInternalVariables: len(cs.internal.variables), + currentR1CDebugID: -1, } - // note: verbose, but we offset the IDs of the wires where they appear, that is, - // in the logs, debug info, constraints and hints - // since we don't use pointers but Terms (uint64), we need to potentially offset - // the same wireID multiple times. - copy(res.ccs.Hints, cs.hints) + // logs, debugInfo and hints are copied, the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires - // TODO @gbotrel we may not want to do that as it may hide some bugs - // if there is a R1C with several unsolved wires, wether they are hint wires or not - // will be problematic at solving time - for i := 0; i < len(cs.hints); i++ { - res.solvedVariables[cs.hints[i].WireID] = true + // we mark hint wires are solved + // each R1C from the frontend.ConstraintSystem is allowed to have at most one unsolved wire + // excluding hints. We mark hint wires as "solved" to ensure spliting R1C to SparseR1C + // won't create invalid SparseR1C constraint with more than one wire to solve for the solver + for vID := range cs.mHints { + res.solvedVariables[vID] = true } // convert the R1C to SparseR1C // in particular, all linear expressions that appear in the R1C // will be split in multiple constraints in the SparseR1C for i := 0; i < len(cs.constraints); i++ { + // we set currentR1CDebugID to the debugInfo ID corresponding to the R1C we're processing + // if present. All constraints created throuh addConstraint will add a new mapping + if dID, ok := cs.mDebug[i]; ok { + res.currentR1CDebugID = dID + } else { + res.currentR1CDebugID = -1 + } res.r1cToSparseR1C(cs.constraints[i]) } - for i := 0; i < len(cs.assertions); i++ { - res.r1cToSparseR1C(cs.assertions[i]) - } // shift variable ID // we want publicWires | privateWires | internalWires @@ -135,28 +147,41 @@ func (cs *ConstraintSystem) toSparseR1CS(curveID ecc.ID) (CompiledConstraintSyst offsetTermID(&r1c.M[1]) } - // offset IDs in the logs + // we need to offset the ids in logs & debugInfo for i := 0; i < len(cs.logs); i++ { - entry := compiled.LogEntry{ - Format: cs.logs[i].format, - ToResolve: make([]int, len(cs.logs[i].toResolve)), + res.ccs.Logs[i] = compiled.LogEntry{ + Format: cs.logs[i].Format, + ToResolve: make([]compiled.Term, len(cs.logs[i].ToResolve)), + } + copy(res.ccs.Logs[i].ToResolve, cs.logs[i].ToResolve) + + for j := 0; j < len(res.ccs.Logs[i].ToResolve); j++ { + offsetTermID(&res.ccs.Logs[i].ToResolve[j]) + } + } + for i := 0; i < len(cs.debugInfo); i++ { + res.ccs.DebugInfo[i] = compiled.LogEntry{ + Format: cs.debugInfo[i].Format, + ToResolve: make([]compiled.Term, len(cs.debugInfo[i].ToResolve)), } - for j := 0; j < len(cs.logs[i].toResolve); j++ { - _, cID, cVisibility := cs.logs[i].toResolve[j].Unpack() - entry.ToResolve[j] = shiftVID(cID, cVisibility) + copy(res.ccs.DebugInfo[i].ToResolve, cs.debugInfo[i].ToResolve) + + for j := 0; j < len(res.ccs.DebugInfo[i].ToResolve); j++ { + offsetTermID(&res.ccs.DebugInfo[i].ToResolve[j]) } - res.ccs.Logs[i] = entry } // we need to offset the ids in the hints - for i := 0; i < len(res.ccs.Hints); i++ { - res.ccs.Hints[i].WireID = shiftVID(res.ccs.Hints[i].WireID, compiled.Internal) - for j := 0; j < len(res.ccs.Hints[i].Inputs); j++ { - l := res.ccs.Hints[i].Inputs[j] - for k := 0; k < len(l); k++ { - offsetTermID(&l[k]) + for vID, hint := range cs.mHints { + k := shiftVID(vID, compiled.Internal) + inputs := make([]compiled.LinearExpression, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + for k := 0; k < len(inputs[j]); k++ { + offsetTermID(&inputs[j][k]) } } + res.ccs.MHints[k] = compiled.Hint{ID: hint.ID, Inputs: inputs} } // update number of internal variables with new wires created @@ -295,6 +320,9 @@ func (scs *sparseR1CS) addConstraint(c compiled.SparseR1C) { if c.M[1] == 0 { c.M[1].SetVariableID(c.R.VariableID()) } + if scs.currentR1CDebugID != -1 { + scs.ccs.MDebug[len(scs.ccs.Constraints)] = scs.currentR1CDebugID + } scs.ccs.Constraints = append(scs.ccs.Constraints, c) } @@ -391,6 +419,7 @@ func (scs *sparseR1CS) split(a compiled.Term, l compiled.LinearExpression) compi // r1cToSparseR1C splits a r1c constraint func (scs *sparseR1CS) r1cToSparseR1C(r1c compiled.R1C) { + // find if the variable to solve is in the left, right, or o linear expression lro, idCS := findUnsolvedVariable(r1c, scs.solvedVariables) if lro == -1 { diff --git a/frontend/frontend.go b/frontend/frontend.go index 5ab626e6e2..678af9f067 100644 --- a/frontend/frontend.go +++ b/frontend/frontend.go @@ -41,11 +41,6 @@ func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, initialCapacity return nil, err } - // sanity checks - if err := cs.sanityCheck(); err != nil { - return nil, err - } - switch zkpID { case backend.GROTH16: ccs, err = cs.toR1CS(curveID) @@ -61,50 +56,6 @@ func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, initialCapacity return } -// sanityCheck ensures: -// * all constraints must have at most one unsolved wire, excluding hint wires -func (cs *ConstraintSystem) sanityCheck() error { - - solved := make([]bool, len(cs.internal.variables)) - for i := 0; i < len(cs.hints); i++ { - solved[cs.hints[i].WireID] = true - } - - countUnsolved := func(r1c compiled.R1C) int { - c := 0 - for i := 0; i < len(r1c.L); i++ { - _, vID, visibility := r1c.L[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - for i := 0; i < len(r1c.R); i++ { - _, vID, visibility := r1c.R[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - for i := 0; i < len(r1c.O); i++ { - _, vID, visibility := r1c.O[i].Unpack() - if visibility == compiled.Internal && !solved[vID] { - c++ - solved[vID] = true - } - } - return c - } - - for _, r1c := range cs.constraints { - if countUnsolved(r1c) > 1 { - return errors.New("constraint system has invalid constraints with multiple unsolved wire") - } - } - - return nil -} - // buildCS builds the constraint system. It bootstraps the inputs // allocations by parsing the circuit's underlying structure, then // it builds the constraint system using the Define method. @@ -113,6 +64,8 @@ func buildCS(curveID ecc.ID, circuit Circuit, initialCapacity ...int) (cs Constr defer func() { if r := recover(); r != nil { err = fmt.Errorf("%v", r) + // TODO @gbotrel with debug buiild tag + // fmt.Println(string(debug.Stack())) } }() // instantiate our constraint system diff --git a/frontend/fuzz.go b/frontend/fuzz.go index ab498886e2..a312bf5d44 100644 --- a/frontend/fuzz.go +++ b/frontend/fuzz.go @@ -22,13 +22,13 @@ func Fuzz(data []byte) int { curves := []ecc.ID{ecc.BN254, ecc.BLS12_381} for _, curveID := range curves { - _, _ = CsFuzzed(data, curveID) + _ = CsFuzzed(data, curveID) } return 1 } -func CsFuzzed(data []byte, curveID ecc.ID) (ccs CompiledConstraintSystem, nbAssertions int) { +func CsFuzzed(data []byte, curveID ecc.ID) (ccs CompiledConstraintSystem) { cs := newConstraintSystem() reader := bytes.NewReader(data) @@ -126,7 +126,7 @@ compile: if err != nil { panic(fmt.Sprintf("compiling (curve %s) failed: %v", curveID.String(), err)) } - return ccs, len(cs.assertions) + return ccs } func (cs *ConstraintSystem) shuffleVariables(seed int64, withConstant bool) []interface{} { diff --git a/go.mod b/go.mod index 715a08287d..8bd8eaa3b6 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,11 @@ go 1.16 require ( github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871 - github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7 + github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0 github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.2.0 github.com/kr/pretty v0.2.0 // indirect github.com/leanovate/gopter v0.2.9 - github.com/pkg/profile v1.6.0 // indirect github.com/stretchr/testify v1.7.0 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect diff --git a/go.sum b/go.sum index cdee21cff1..8c70a887cf 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,7 @@ github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871 h1:gfdz2r/E4uQhD8jDUv2SaWQClfzFuZioHGAzPw7oZng= github.com/consensys/bavard v0.1.8-0.20210806153619-fcffe4ffd871/go.mod h1:Bpd0/3mZuaj6Sj+PqrmIquiOKy397AKGThQPaGzNXAQ= -github.com/consensys/gnark-crypto v0.4.1-0.20210818174051-018b86471fca h1:YuKivJirttUz/FNlAp1dwIiJiYyPOoyno2CoRlfqMNs= -github.com/consensys/gnark-crypto v0.4.1-0.20210818174051-018b86471fca/go.mod h1:5u+nS08qZhHtugNg17dAnCGqbnRCJ6XSdPj0LyFvAOM= -github.com/consensys/gnark-crypto v0.5.0 h1:c+1SOpCPKmw5lKth/hIoRgcw23KSgWnNR/b5M+JRC3k= -github.com/consensys/gnark-crypto v0.5.0/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= -github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618 h1:vnrIRUFj8afz/QlnRlD+AVLkgkyYj6JD6MNLK6R7dcg= -github.com/consensys/gnark-crypto v0.5.1-0.20210907173531-0ae8b5c38618/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= -github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7 h1:2k7ImGxDTTY2OpiKjnFDfqc/ir8O54qCwUTnobfDbkM= -github.com/consensys/gnark-crypto v0.5.1-0.20210907174324-9721833081d7/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= +github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0 h1:ODfAG0P/XaGvh1JNZM9tzL2MKVaqFdE7FeATcrdrHB0= +github.com/consensys/gnark-crypto v0.5.1-0.20210917183421-cb36b2c871c0/go.mod h1:wAZ9dsKCDVTSIy2KVTik+ZF16GUX9qp96mxFBDl9iAQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -20,7 +14,6 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= -github.com/pkg/profile v1.6.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/integration_test.go b/integration_test.go index 777f1f4b2c..36a09b0500 100644 --- a/integration_test.go +++ b/integration_test.go @@ -68,10 +68,10 @@ func TestIntegrationAPI(t *testing.T) { pk, vk, err := groth16.Setup(ccs) assert.NoError(err) - correctProof, err := groth16.Prove(ccs, pk, circuit.Good, nil) + correctProof, err := groth16.Prove(ccs, pk, circuit.Good) assert.NoError(err) - wrongProof, err := groth16.Prove(ccs, pk, circuit.Bad, nil, true) + wrongProof, err := groth16.Prove(ccs, pk, circuit.Bad, backend.IgnoreSolverError) assert.NoError(err) assert.NoError(groth16.Verify(correctProof, vk, circuit.Good)) @@ -84,7 +84,7 @@ func TestIntegrationAPI(t *testing.T) { _, err := witness.WriteFullTo(&buf, curve, circuit.Good) assert.NoError(err) - correctProof, err := groth16.ReadAndProve(ccs, pk, &buf, nil) + correctProof, err := groth16.ReadAndProve(ccs, pk, &buf) assert.NoError(err) buf.Reset() @@ -114,10 +114,10 @@ func TestIntegrationAPI(t *testing.T) { pk, vk, err := plonk.Setup(ccs, srs) assert.NoError(err) - correctProof, err := plonk.Prove(ccs, pk, circuit.Good, nil) + correctProof, err := plonk.Prove(ccs, pk, circuit.Good) assert.NoError(err) - wrongProof, err := plonk.Prove(ccs, pk, circuit.Bad, nil, true) + wrongProof, err := plonk.Prove(ccs, pk, circuit.Bad, backend.IgnoreSolverError) assert.NoError(err) assert.NoError(plonk.Verify(correctProof, vk, circuit.Good)) @@ -130,7 +130,7 @@ func TestIntegrationAPI(t *testing.T) { _, err := witness.WriteFullTo(&buf, curve, circuit.Good) assert.NoError(err) - correctProof, err := plonk.ReadAndProve(ccs, pk, &buf, nil) + correctProof, err := plonk.ReadAndProve(ccs, pk, &buf) assert.NoError(err) buf.Reset() diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index ef03182f09..e9b60cec33 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,8 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -49,14 +46,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -65,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -78,8 +72,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -95,16 +89,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +101,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +112,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -140,23 +130,14 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -203,10 +184,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -226,9 +204,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { @@ -241,19 +218,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +238,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +253,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +268,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -309,15 +284,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -328,10 +295,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -339,7 +306,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -360,7 +327,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -396,24 +363,18 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor @@ -427,8 +388,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index e2e5435a99..22421afdde 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -37,9 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -47,13 +46,12 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -61,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -78,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -94,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -109,7 +106,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,8 +132,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -143,8 +144,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -154,8 +155,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { @@ -223,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -265,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -295,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -322,10 +305,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } @@ -337,6 +324,12 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index f4149a87e9..1409e76e89 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls12-377/cs/cs.go b/internal/backend/bls12-377/cs/solution.go similarity index 75% rename from internal/backend/bls12-377/cs/cs.go rename to internal/backend/bls12-377/cs/solution.go index 6a018df255..6a24bb28cb 100644 --- a/internal/backend/bls12-377/cs/cs.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) @@ -149,12 +149,65 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go index a85b69f9df..124493eb33 100644 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ b/internal/backend/bls12-377/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls12_377groth16.ProvingKey var vk bls12_377groth16.VerifyingKey bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls12_377groth16.ProvingKey var vk bls12_377groth16.VerifyingKey bls12_377groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_377groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index bb12874e91..2bde0d9440 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls12_377witness "github.com/consensys/gnark/internal/backend/bls12-377/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,21 +53,19 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_377witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 971ee42f8d..3e53f9b514 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 9d1e8218c3..ae5e8fb660 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_377plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 2239a27ad6..c1efe643ef 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls12-377/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index b035e3e6cd..d4fc1a806d 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,8 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -49,14 +46,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -65,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -78,8 +72,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -95,16 +89,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +101,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +112,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -140,23 +130,14 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -203,10 +184,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -226,9 +204,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { @@ -241,19 +218,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +238,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +253,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +268,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -309,15 +284,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -328,10 +295,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -339,7 +306,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -360,7 +327,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -396,24 +363,18 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor @@ -427,8 +388,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index b72a43faf7..04909b7220 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -37,9 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -47,13 +46,12 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -61,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -78,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -94,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -109,7 +106,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,8 +132,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -143,8 +144,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -154,8 +155,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { @@ -223,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -265,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -295,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -322,10 +305,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } @@ -337,6 +324,12 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index cf34792a3e..a270c89096 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls12-381/cs/cs.go b/internal/backend/bls12-381/cs/solution.go similarity index 75% rename from internal/backend/bls12-381/cs/cs.go rename to internal/backend/bls12-381/cs/solution.go index bd5292d077..2744adf6d0 100644 --- a/internal/backend/bls12-381/cs/cs.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) @@ -149,12 +149,65 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bls12-381/groth16/groth16_test.go b/internal/backend/bls12-381/groth16/groth16_test.go index bbdf0274c8..3c948011ed 100644 --- a/internal/backend/bls12-381/groth16/groth16_test.go +++ b/internal/backend/bls12-381/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls12_381groth16.ProvingKey var vk bls12_381groth16.VerifyingKey bls12_381groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls12_381groth16.ProvingKey var vk bls12_381groth16.VerifyingKey bls12_381groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls12_381groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go index c2ad6ec166..eb053866f6 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/internal/backend/bls12-381/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls12_381witness "github.com/consensys/gnark/internal/backend/bls12-381/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,21 +53,19 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls12_381witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index 71228b06a7..4cc996279c 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bls12-381/plonk/plonk_test.go b/internal/backend/bls12-381/plonk/plonk_test.go index 502e624e51..2dd84f8001 100644 --- a/internal/backend/bls12-381/plonk/plonk_test.go +++ b/internal/backend/bls12-381/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls12_381plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index d84d5cc20a..9ab6a7bdca 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls12-381/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 69b737db99..5943d58eb3 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,8 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -49,14 +46,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -65,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -78,8 +72,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -95,16 +89,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +101,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +112,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -140,23 +130,14 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -203,10 +184,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -226,9 +204,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { @@ -241,19 +218,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +238,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +253,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +268,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -309,15 +284,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -328,10 +295,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -339,7 +306,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -360,7 +327,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -396,24 +363,18 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor @@ -427,8 +388,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 36b81c1042..a3def050cb 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -37,9 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -47,13 +46,12 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -61,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -78,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -94,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -109,7 +106,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,8 +132,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -143,8 +144,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -154,8 +155,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { @@ -223,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -265,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -295,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -322,10 +305,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } @@ -337,6 +324,12 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 51dc21d66c..6608db8102 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bls24-315/cs/cs.go b/internal/backend/bls24-315/cs/solution.go similarity index 75% rename from internal/backend/bls24-315/cs/cs.go rename to internal/backend/bls24-315/cs/solution.go index d260069445..f2cc9e8d08 100644 --- a/internal/backend/bls24-315/cs/cs.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) @@ -149,12 +149,65 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bls24-315/groth16/groth16_test.go b/internal/backend/bls24-315/groth16/groth16_test.go index 9ed77d4fc1..5053d1d88e 100644 --- a/internal/backend/bls24-315/groth16/groth16_test.go +++ b/internal/backend/bls24-315/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bls24_315groth16.ProvingKey var vk bls24_315groth16.VerifyingKey bls24_315groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bls24_315groth16.ProvingKey var vk bls24_315groth16.VerifyingKey bls24_315groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bls24_315groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go index deaad3f871..916be2bf8e 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/internal/backend/bls24-315/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bls24_315witness "github.com/consensys/gnark/internal/backend/bls24-315/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,21 +53,19 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bls24_315witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index 9c5a225cb6..ec5ebda7ec 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bls24-315/plonk/plonk_test.go b/internal/backend/bls24-315/plonk/plonk_test.go index f823c8cb18..94497ead20 100644 --- a/internal/backend/bls24-315/plonk/plonk_test.go +++ b/internal/backend/bls24-315/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bls24_315plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index 9ad9b3741e..a772beec84 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bls24-315/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 906d480e87..36a9a3dd23 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,8 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -49,14 +46,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -65,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -78,8 +72,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -95,16 +89,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +101,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +112,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -140,23 +130,14 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -203,10 +184,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -226,9 +204,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { @@ -241,19 +218,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +238,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +253,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +268,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -309,15 +284,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -328,10 +295,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -339,7 +306,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -360,7 +327,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -396,24 +363,18 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor @@ -427,8 +388,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 4fedb6f24d..4fdc2d163c 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -37,9 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -47,13 +46,12 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -61,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -78,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -94,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -109,7 +106,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,8 +132,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -143,8 +144,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -154,8 +155,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { @@ -223,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -265,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -295,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -322,10 +305,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } @@ -337,6 +324,12 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index bdda6cea67..146faec6f6 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -48,9 +48,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bn254/cs/cs.go b/internal/backend/bn254/cs/solution.go similarity index 75% rename from internal/backend/bn254/cs/cs.go rename to internal/backend/bn254/cs/solution.go index 00da483db7..e46aebcc25 100644 --- a/internal/backend/bn254/cs/cs.go +++ b/internal/backend/bn254/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) @@ -149,12 +149,65 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bn254/groth16/groth16_test.go b/internal/backend/bn254/groth16/groth16_test.go index 884cd99dfb..a59471e6c0 100644 --- a/internal/backend/bn254/groth16/groth16_test.go +++ b/internal/backend/bn254/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bn254groth16.ProvingKey var vk bn254groth16.VerifyingKey bn254groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bn254groth16.ProvingKey var vk bn254groth16.VerifyingKey bn254groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bn254groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go index aac9da3b2a..5fcbef9f6a 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/internal/backend/bn254/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bn254witness "github.com/consensys/gnark/internal/backend/bn254/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,21 +53,19 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bn254witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 97e00dcf86..dc710839d3 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bn254/plonk/plonk_test.go b/internal/backend/bn254/plonk/plonk_test.go index 58a3e3d0cd..6e26a1d314 100644 --- a/internal/backend/bn254/plonk/plonk_test.go +++ b/internal/backend/bn254/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bn254plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 1c3ccc4e7d..0421b01dee 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bn254/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 669b5537d6..962e7f970e 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -21,12 +21,11 @@ import ( "fmt" "io" "math/big" - "os" "strings" "github.com/fxamacker/cbor/v2" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -40,8 +39,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -49,14 +46,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -65,10 +59,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -78,8 +72,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi } // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints) { - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints) { + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } solution.solved[0] = true // ONE_WIRE @@ -95,16 +89,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) @@ -112,11 +101,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -124,9 +112,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -140,23 +130,14 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // mulByCoeff sets res = res * t.Coeff func (cs *R1CS) mulByCoeff(res *fr.Element, t compiled.Term) { cID := t.CoeffID() @@ -203,10 +184,7 @@ func (cs *R1CS) instantiateR1C(r compiled.R1C, solution *solution) (a, b, c fr.E // It returns the 1 if the the position to solve is in the quadratic part (it // means that there is a division and serves to navigate in the log info for the // computational constraints), and 0 otherwise. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error) { - - // value to return: 1 if the wire to solve is in the quadratic term, 0 otherwise - var offset uint +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -226,9 +204,8 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error } // first we check if this is a hint wire - if hID, ok := cs.mHints[vID]; ok { - // TODO handle error - return solution.solveHint(cs.Hints[hID], vID) + if hint, ok := cs.MHints[vID]; ok { + return solution.solveWithHint(vID, hint) } if loc != 0 { @@ -241,19 +218,19 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error for _, t := range r.L { if err := processTerm(t, &a, 1); err != nil { - return 0, err + return err } } for _, t := range r.R { if err := processTerm(t, &b, 2); err != nil { - return 0, err + return err } } for _, t := range r.O { if err := processTerm(t, &c, 3); err != nil { - return 0, err + return err } } @@ -261,7 +238,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return 0, nil + return nil } // we compute the wire value and instantiate it @@ -276,14 +253,12 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error wire.Div(&c, &b). Sub(&wire, &a) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 2: if !a.IsZero() { wire.Div(&c, &a). Sub(&wire, &b) cs.mulByCoeff(&wire, termToCompute) - offset = 1 } case 3: wire.Mul(&a, &b). @@ -293,7 +268,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (uint, error solution.set(vID, wire) - return offset, nil + return nil } // TODO @gbotrel clean logs and html @@ -309,15 +284,7 @@ func (cs *R1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *R1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - return t.Execute(w, &d) + return t.Execute(w, cs) } func add(a, b int) int { @@ -328,10 +295,10 @@ func sub(a, b int) int { return a - b } -func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int) string { +func toHTML(l compiled.LinearExpression, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder for i := 0; i < len(l); i++ { - termToHTML(l[i], &sbb, coeffs, mHints, false) + termToHTML(l[i], &sbb, coeffs, MHints, false) if i+1 < len(l) { sbb.WriteString(" + ") } @@ -339,7 +306,7 @@ func toHTML(l compiled.LinearExpression, coeffs []fr.Element, mHints map[int]int return sbb.String() } -func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHints map[int]int, offset bool) { +func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, MHints map[int]int, offset bool) { tID := t.CoeffID() if tID == compiled.CoeffIdOne { // do nothing, just print the variable @@ -360,7 +327,7 @@ func termToHTML(t compiled.Term, sbb *strings.Builder, coeffs []fr.Element, mHin switch t.VariableVisibility() { case compiled.Internal: class = "internal" - if _, ok := mHints[vID]; ok { + if _, ok := MHints[vID]; ok { class = "hint" } case compiled.Public: @@ -396,24 +363,18 @@ func (cs *R1CS) FrSize() int { return fr.Limbs * 8 } -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *R1CS) SetLoggerOutput(w io.Writer) { - cs.loggerOut = w -} - // WriteTo encodes R1CS into provided io.Writer using cbor func (cs *R1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) - - // encode our object - if err := encoder.Encode(cs); err != nil { - return _w.N, err + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err } + encoder := enc.NewEncoder(&_w) - return _w.N, nil + // encode our object + err = encoder.Encode(cs) + return _w.N, err } // ReadFrom attempts to decode R1CS from io.Reader using cbor @@ -427,8 +388,5 @@ func (cs *R1CS) ReadFrom(r io.Reader) (int64, error) { return int64(decoder.NumBytesRead()), err } - // init the hint map - cs.initHints() - return int64(decoder.NumBytesRead()), nil } diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index dc334eadae..5adf8da3b8 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -26,7 +26,7 @@ import ( "strings" "text/template" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" @@ -37,9 +37,8 @@ import ( type SparseR1CS struct { compiled.SparseR1CS - // Coefficients in the constraints - Coefficients []fr.Element // list of unique coefficients. - mHints map[int]int // correspondance between hint wire ID and hint data struct + Coefficients []fr.Element // coefficients in the constraints + loggerOut io.Writer } // NewSparseR1CS returns a new SparseR1CS and sets r1cs.Coefficient (fr.Element) from provided big.Int values @@ -47,13 +46,12 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS cs := SparseR1CS{ SparseR1CS: ccs, Coefficients: make([]fr.Element, len(coefficients)), + loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { cs.Coefficients[i].SetBigInt(&coefficients[i]) } - cs.initHints() - return &cs } @@ -61,7 +59,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // solution.values = [publicInputs | secretInputs | internalVariables ] // witness: contains the input variables // it returns the full slice of wires -func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) (values []fr.Element, err error) { +func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables @@ -78,7 +76,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) } // keep track of wire that have a value - solution, err := newSolution(nbVariables, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbVariables, opt.HintFunctions, cs.Coefficients) if err != nil { return solution.values, err } @@ -94,8 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) solution.nbSolved += len(witness) // defer log printing once all solution.values are computed - // TODO @gbotrel replace stdout by writer set by user, same as in R1CS - defer solution.printLogs(os.Stdout, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // batch invert the coefficients to avoid many divisions in the solver coefficientsNegInv := fr.BatchInvert(cs.Coefficients) @@ -109,7 +106,11 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, hintFunctions []hint.Function) return solution.values, fmt.Errorf("constraint %d: %w", i, err) } if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,8 +132,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.L.CoeffID() != 0 || c.M[0].CoeffID() != 0) && !solution.solved[lID] { // check if it's a hint - if hID, ok := cs.mHints[lID]; ok { - if err := solution.solveHint(cs.Hints[hID], lID); err != nil { + if hint, ok := cs.MHints[lID]; ok { + if err := solution.solveWithHint(lID, hint); err != nil { return -1, err } } else { @@ -143,8 +144,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.R.CoeffID() != 0 || c.M[1].CoeffID() != 0) && !solution.solved[rID] { // check if it's a hint - if hID, ok := cs.mHints[rID]; ok { - if err := solution.solveHint(cs.Hints[hID], rID); err != nil { + if hint, ok := cs.MHints[rID]; ok { + if err := solution.solveWithHint(rID, hint); err != nil { return -1, err } } else { @@ -154,8 +155,8 @@ func (cs *SparseR1CS) computeHints(c compiled.SparseR1C, solution *solution) (in if (c.O.CoeffID() != 0) && !solution.solved[oID] { // check if it's a hint - if hID, ok := cs.mHints[oID]; ok { - if err := solution.solveHint(cs.Hints[hID], oID); err != nil { + if hint, ok := cs.MHints[oID]; ok { + if err := solution.solveWithHint(oID, hint); err != nil { return -1, err } } else { @@ -223,8 +224,8 @@ func (cs *SparseR1CS) solveConstraint(c compiled.SparseR1C, solution *solution, // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps r1cs.Solve() and allocates r1cs.Solve() inputs -func (cs *SparseR1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - _, err := cs.Solve(witness, hintFunctions) +func (cs *SparseR1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + _, err := cs.Solve(witness, opt) return err } @@ -265,21 +266,12 @@ func (cs *SparseR1CS) ToHTML(w io.Writer) error { return err } - type data struct { - *SparseR1CS - MHints map[int]int - } - d := data{ - cs, - cs.mHints, - } - - return t.Execute(w, &d) + return t.Execute(w, cs) } -func toHTMLTerm(t compiled.Term, coeffs []fr.Element, mHints map[int]int) string { +func toHTMLTerm(t compiled.Term, coeffs []fr.Element, MHints map[int]int) string { var sbb strings.Builder - termToHTML(t, &sbb, coeffs, mHints, true) + termToHTML(t, &sbb, coeffs, MHints, true) return sbb.String() } @@ -295,15 +287,6 @@ func toHTMLCoeff(cID int, coeffs []fr.Element) string { return sbb.String() } -func (cs *SparseR1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i := 0; i < len(cs.Hints); i++ { - cs.mHints[cs.Hints[i].WireID] = i - } -} - // FrSize return fr.Limbs * 8, size in byte of a fr element func (cs *SparseR1CS) FrSize() int { return fr.Limbs * 8 @@ -322,10 +305,14 @@ func (cs *SparseR1CS) CurveID() ecc.ID { // WriteTo encodes SparseR1CS into provided io.Writer using cbor func (cs *SparseR1CS) WriteTo(w io.Writer) (int64, error) { _w := ioutils.WriterCounter{W: w} // wraps writer to count the bytes written - encoder := cbor.NewEncoder(&_w) + enc, err := cbor.CoreDetEncOptions().EncMode() + if err != nil { + return 0, err + } + encoder := enc.NewEncoder(&_w) // encode our object - err := encoder.Encode(cs) + err = encoder.Encode(cs) return _w.N, err } @@ -337,6 +324,12 @@ func (cs *SparseR1CS) ReadFrom(r io.Reader) (int64, error) { } decoder := dm.NewDecoder(r) err = decoder.Decode(cs) - cs.initHints() return int64(decoder.NumBytesRead()), err } + +// SetLoggerOutput replace existing logger output with provided one +// default uses os.Stdout +// if nil is provided, logs are not printed +func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { + cs.loggerOut = w +} diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 5289783da9..7601610373 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -52,9 +52,6 @@ func TestSerialization(t *testing.T) { t.Fatal(err) } - // no need to serialize. - r1cs.SetLoggerOutput(nil) - r1cs2.SetLoggerOutput(nil) { buffer.Reset() t.Log(name) diff --git a/internal/backend/bw6-761/cs/cs.go b/internal/backend/bw6-761/cs/solution.go similarity index 75% rename from internal/backend/bw6-761/cs/cs.go rename to internal/backend/bw6-761/cs/solution.go index 6ea550d88f..deceb66fb9 100644 --- a/internal/backend/bw6-761/cs/cs.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -105,7 +105,7 @@ func (s *solution) computeTerm(t compiled.Term) fr.Element { } // solveHint compute solution.values[vID] using provided solver hint -func (s *solution) solveHint(h compiled.Hint, vID int) error { +func (s *solution) solveWithHint(vID int, h compiled.Hint) error { // compute values for all inputs. inputs := make([]fr.Element, len(h.Inputs)) @@ -149,12 +149,65 @@ func (s *solution) printLogs(w io.Writer, logs []compiled.LogEntry) { } } +const unsolvedVariable = "" + func (s *solution) logValue(log compiled.LogEntry) string { var toResolve []interface{} + var ( + isEval bool + eval fr.Element + missingValue bool + ) for j := 0; j < len(log.ToResolve); j++ { - vID := log.ToResolve[j] + if log.ToResolve[j] == compiled.TermDelimitor { + // this is a special case where we want to evaluate the following terms until the next delimitor. + if !isEval { + isEval = true + missingValue = false + eval.SetZero() + continue + } + isEval = false + if missingValue { + toResolve = append(toResolve, unsolvedVariable) + } else { + // we have to append our accumulator + toResolve = append(toResolve, eval.String()) + } + continue + } + cID, vID, visibility := log.ToResolve[j].Unpack() + + if isEval { + // we are evaluating + if visibility == compiled.Virtual { + // just add the constant + eval.Add(&eval, &s.coefficients[cID]) + continue + } + if !s.solved[vID] { + missingValue = true + continue + } + tv := s.computeTerm(log.ToResolve[j]) + eval.Add(&eval, &tv) + continue + } + + if visibility == compiled.Virtual { + // it's just a constant + if cID == compiled.CoeffIdMinusOne { + toResolve = append(toResolve, "-1") + } else { + toResolve = append(toResolve, s.coefficients[cID].String()) + } + continue + } + if !(cID == compiled.CoeffIdMinusOne || cID == compiled.CoeffIdOne) { + toResolve = append(toResolve, s.coefficients[cID].String()) + } if !s.solved[vID] { - toResolve = append(toResolve, "???") + toResolve = append(toResolve, unsolvedVariable) } else { toResolve = append(toResolve, s.values[vID].String()) } diff --git a/internal/backend/bw6-761/groth16/groth16_test.go b/internal/backend/bw6-761/groth16/groth16_test.go index ca19ed0afe..f84636e32b 100644 --- a/internal/backend/bw6-761/groth16/groth16_test.go +++ b/internal/backend/bw6-761/groth16/groth16_test.go @@ -130,7 +130,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() b.Run("prover", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + _, _ = bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) } }) } @@ -151,7 +151,7 @@ func BenchmarkVerifier(b *testing.B) { var pk bw6_761groth16.ProvingKey var vk bw6_761groth16.VerifyingKey bw6_761groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -175,7 +175,7 @@ func BenchmarkSerialization(b *testing.B) { var pk bw6_761groth16.ProvingKey var vk bw6_761groth16.VerifyingKey bw6_761groth16.Setup(r1cs.(*cs.R1CS), &pk, &vk) - proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, nil, false) + proof, err := bw6_761groth16.Prove(r1cs.(*cs.R1CS), &pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go index 6d81352e53..c3995877ca 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/internal/backend/bw6-761/groth16/prove.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" bw6_761witness "github.com/consensys/gnark/internal/backend/bw6-761/witness" "github.com/consensys/gnark/internal/utils" "math/big" @@ -53,21 +53,19 @@ func (proof *Proof) CurveID() ecc.ID { } // Prove generates the proof of knoweldge of a r1cs with full witness (secret + public part). -// if force flag is set, Prove ignores R1CS solving error (ie invalid witness) and executes -// the FFTs and MultiExponentiations to compute an (invalid) Proof object -func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(r1cs *cs.R1CS, pk *ProvingKey, witness bw6_761witness.Witness, opt backend.ProverOption) (*Proof, error) { if len(witness) != int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables) { return nil, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(r1cs.NbPublicVariables-1+r1cs.NbSecretVariables), r1cs.NbPublicVariables, r1cs.NbSecretVariables) } // solve the R1CS and compute the a, b, c vectors - a := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - b := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) - c := make([]fr.Element, r1cs.NbConstraints, pk.Domain.Cardinality) + a := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + b := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) + c := make([]fr.Element, len(r1cs.Constraints), pk.Domain.Cardinality) var wireValues []fr.Element var err error - if wireValues, err = r1cs.Solve(witness, a, b, c, hintFunctions); err != nil { - if !force { + if wireValues, err = r1cs.Solve(witness, a, b, c, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill wireValues with random values else multi exps don't do much diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 68745e3335..b9cacb8339 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(r1cs.NbConstraints), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -412,7 +412,7 @@ func sampleToxicWaste() (toxicWaste, error) { func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { // get R1CS nb constraints, wires and public/private inputs nbWires := r1cs.NbInternalVariables + r1cs.NbPublicVariables + r1cs.NbSecretVariables - nbConstraints := r1cs.NbConstraints + nbConstraints := len(r1cs.Constraints) // Setting group for fft domain := fft.NewDomain(uint64(nbConstraints), 1, true) diff --git a/internal/backend/bw6-761/plonk/plonk_test.go b/internal/backend/bw6-761/plonk/plonk_test.go index 2089e9e37c..b9728c4d75 100644 --- a/internal/backend/bw6-761/plonk/plonk_test.go +++ b/internal/backend/bw6-761/plonk/plonk_test.go @@ -135,7 +135,7 @@ func BenchmarkProver(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err = bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + _, err = bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } @@ -160,7 +160,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { panic(err) } @@ -184,7 +184,7 @@ func BenchmarkSerialization(b *testing.B) { b.Fatal(err) } - proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, nil, false) + proof, err := bw6_761plonk.Prove(ccs.(*cs.SparseR1CS), pk, fullWitness, backend.ProverOption{}) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index ad79236bf8..0a52b7f156 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -38,7 +38,7 @@ import ( "github.com/consensys/gnark/internal/backend/bw6-761/cs" "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -60,7 +60,7 @@ type Proof struct { } // Prove from the public data -func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witness, hintFunctions []hint.Function, force bool) (*Proof, error) { +func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witness, opt backend.ProverOption) (*Proof, error) { // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -74,8 +74,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // compute the constraint system solution var solution []fr.Element var err error - if solution, err = spr.Solve(fullWitness, hintFunctions); err != nil { - if !force { + if solution, err = spr.Solve(fullWitness, opt); err != nil { + if !opt.Force { return nil, err } else { // we need to fill solution with random values diff --git a/internal/backend/compiled/cs.go b/internal/backend/compiled/cs.go new file mode 100644 index 0000000000..a8e6c17e6a --- /dev/null +++ b/internal/backend/compiled/cs.go @@ -0,0 +1,74 @@ +package compiled + +import ( + "io" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/hint" +) + +// CS contains common element between R1CS and CS +type CS struct { + // number of wires + NbInternalVariables int + NbPublicVariables int + NbSecretVariables int + + // logs (added with cs.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a cs.API that + // results in an unsolved constraint + DebugInfo []LogEntry + + // maps wire id to hint + // a wire may point to at most one hint + MHints map[int]Hint + + // maps constraint id to debugInfo id + // several constraints may point to the same debug info + MDebug map[int]int +} + +// Visibility encodes a Variable (or wire) visibility +// Possible values are Unset, Internal, Secret or Public +type Visibility uint8 + +const ( + Unset Visibility = iota + Internal + Secret + Public + Virtual +) + +// Hint represents a solver hint +// it enables the solver to compute a Wire with a function provided at solving time +// using pre-defined inputs +type Hint struct { + ID hint.ID // hint function id + Inputs []LinearExpression // terms to inject in the hint function +} + +// GetNbVariables return number of internal, secret and public variables +func (cs *CS) GetNbVariables() (internal, secret, public int) { + return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables +} + +// FrSize panics +func (cs *CS) FrSize() int { panic("not implemented") } + +// GetNbCoefficients panics +func (cs *CS) GetNbCoefficients() int { panic("not implemented") } + +// CurveID returns ecc.UNKNOWN +func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } + +// WriteTo panics +func (cs *CS) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } + +// ReadFrom panics +func (cs *CS) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } + +// ToHTML panics +func (cs *CS) ToHTML(w io.Writer) error { panic("not implemtened") } diff --git a/internal/backend/compiled/log.go b/internal/backend/compiled/log.go new file mode 100644 index 0000000000..e65a84d419 --- /dev/null +++ b/internal/backend/compiled/log.go @@ -0,0 +1,87 @@ +package compiled + +import ( + "runtime" + "strconv" + "strings" +) + +// LogEntry is used as a shared data structure between the frontend and the backend +// to represent string values (in logs or debug info) where a value is not known at compile time +// (which is the case for variables that need to be resolved in the R1CS) +type LogEntry struct { + Format string + ToResolve []Term +} + +func (l *LogEntry) WriteLinearExpression(le LinearExpression, sbb *strings.Builder) { + sbb.Grow(len(le) * len(" + (xx + xxxxxxxxxxxx")) + + for i := 0; i < len(le); i++ { + if i > 0 { + sbb.WriteString(" + ") + } + l.WriteTerm(le[i], sbb) + } +} + +func (l *LogEntry) WriteTerm(t Term, sbb *strings.Builder) { + // virtual == only a coeff, we discard the wire + if t.VariableVisibility() == Public && t.VariableID() == 0 { + sbb.WriteString("%s") + t.SetVariableVisibility(Virtual) + l.ToResolve = append(l.ToResolve, t) + return + } + + cID := t.CoeffID() + if cID == CoeffIdMinusOne { + sbb.WriteString("-%s") + } else if cID == CoeffIdOne { + sbb.WriteString("%s") + } else { + sbb.WriteString("%s*%s") + } + + l.ToResolve = append(l.ToResolve, t) +} + +func (l *LogEntry) WriteStack(sbb *strings.Builder) { + // derived from: https://golang.org/pkg/runtime/#example_Frames + // we stop when func name == Define as it is where the gnark circuit code should start + + // Ask runtime.Callers for up to 10 pcs + pc := make([]uintptr, 10) + n := runtime.Callers(3, pc) + if n == 0 { + // No pcs available. Stop now. + // This can happen if the first argument to runtime.Callers is large. + return + } + pc = pc[:n] // pass only valid pcs to runtime.CallersFrames + frames := runtime.CallersFrames(pc) + // Loop to get frames. + // A fixed number of pcs can expand to an indefinite number of Frames. + for { + frame, more := frames.Next() + fe := strings.Split(frame.Function, "/") + function := fe[len(fe)-1] + if strings.Contains(function, "frontend.(*ConstraintSystem)") { + continue + } + + sbb.WriteString(function) + sbb.WriteByte('\n') + sbb.WriteByte('\t') + sbb.WriteString(frame.File) + sbb.WriteByte(':') + sbb.WriteString(strconv.Itoa(frame.Line)) + sbb.WriteByte('\n') + if !more { + break + } + if strings.HasSuffix(function, "Define") { + break + } + } +} diff --git a/internal/backend/compiled/r1c.go b/internal/backend/compiled/r1c.go index 1e08ecec36..783b80e9a8 100644 --- a/internal/backend/compiled/r1c.go +++ b/internal/backend/compiled/r1c.go @@ -16,12 +16,14 @@ package compiled import ( "math/big" - "strconv" "strings" - - "github.com/consensys/gnark/backend/hint" ) +// R1C used to compute the wires +type R1C struct { + L, R, O LinearExpression +} + // LinearExpression represent a linear expression of variables type LinearExpression []Term @@ -52,42 +54,6 @@ func (l LinearExpression) Less(i, j int) bool { return iVis > jVis } -// R1C used to compute the wires -type R1C struct { - L LinearExpression - R LinearExpression - O LinearExpression -} - -// LogEntry is used as a shared data structure between the frontend and the backend -// to represent string values (in logs or debug info) where a value is not known at compile time -// (which is the case for variables that need to be resolved in the R1CS) -type LogEntry struct { - Format string - ToResolve []int -} - -// Visibility encodes a Variable (or wire) visibility -// Possible values are Unset, Internal, Secret or Public -type Visibility uint8 - -const ( - Unset Visibility = iota - Internal - Secret - Public - Virtual -) - -// Hint represents a solver hint -// it enables the solver to compute a Wire with a function provided at solving time -// using pre-defined inputs -type Hint struct { - WireID int // resulting wire ID to compute - ID hint.ID // hint function id - Inputs []LinearExpression // terms to inject in the hint function -} - func (r1c *R1C) String(coeffs []big.Int) string { var sbb strings.Builder sbb.WriteString("L[") @@ -109,23 +75,3 @@ func (l LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { } } } - -func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { - sbb.WriteString(coeffs[t.CoeffID()].String()) - sbb.WriteString("*") - switch t.VariableVisibility() { - case Internal: - sbb.WriteString("i") - case Public: - sbb.WriteString("p") - case Secret: - sbb.WriteString("s") - case Virtual: - sbb.WriteString("v") - case Unset: - sbb.WriteString("u") - default: - panic("not implemented") - } - sbb.WriteString(strconv.Itoa(t.VariableID())) -} diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index 4410d75636..8831f6da01 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -14,75 +14,13 @@ package compiled -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc" -) - -// R1CS decsribes a set of R1CS constraint +// R1CS decsribes a set of R1C constraint type R1CS struct { - // Wires - NbInternalVariables int - NbPublicVariables int // includes ONE wire - NbSecretVariables int - Logs []LogEntry - DebugInfoComputation []LogEntry - - // Constraints - NbConstraints int // total number of constraints - Constraints []R1C - - // Hints - Hints []Hint + CS + Constraints []R1C } // GetNbConstraints returns the number of constraints func (r1cs *R1CS) GetNbConstraints() int { - return r1cs.NbConstraints -} - -// GetNbVariables return number of internal, secret and public variables -func (r1cs *R1CS) GetNbVariables() (internal, secret, public int) { - internal = r1cs.NbInternalVariables - secret = r1cs.NbSecretVariables - public = r1cs.NbPublicVariables - return -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (r1cs *R1CS) GetNbCoefficients() int { - panic("not implemented") -} - -// CurveID returns ecc.UNKNOWN as this is a untyped R1CS using big.Int -func (r1cs *R1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// FrSize panics -func (r1cs *R1CS) FrSize() int { - panic("not implemented") -} - -// WriteTo panics (can't serialize untyped R1CS) -func (r1cs *R1CS) WriteTo(w io.Writer) (n int64, err error) { - panic("not implemented") -} - -// ReadFrom panics (can't deserialize untyped R1CS) -func (r1cs *R1CS) ReadFrom(r io.Reader) (n int64, err error) { - panic("not implemented") -} - -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (r1cs *R1CS) SetLoggerOutput(w io.Writer) { - panic("not implemented") -} - -// ToHTML returns an HTML human-readable representation of the constraint system -func (r1cs *R1CS) ToHTML(w io.Writer) error { - panic("not implemented") + return len(r1cs.Constraints) } diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go index 0ade1f341e..bcd59d5994 100644 --- a/internal/backend/compiled/r1cs_sparse.go +++ b/internal/backend/compiled/r1cs_sparse.go @@ -14,82 +14,13 @@ package compiled -import ( - "io" - - "github.com/consensys/gnark-crypto/ecc" -) - -// SparseR1CS represents a Plonk like circuit +// R1CS decsribes a set of SparseR1C constraint type SparseR1CS struct { - - // Variables [publicVariables| secretVariables | internalVariables ] - NbInternalVariables int - NbPublicVariables int - NbSecretVariables int - - // Constraints contains an ordered list of SparseR1C - // the solver will iterate through them and is guaranteed that there will be at most one - // unsolved wire per constraint + CS Constraints []SparseR1C - - // Logs (e.g. variables that have been printed using cs.Println) - Logs []LogEntry - - // Hints - Hints []Hint -} - -// GetNbVariables return number of internal, secret and public variables -func (cs *SparseR1CS) GetNbVariables() (internal, secret, public int) { - internal = cs.NbInternalVariables - secret = cs.NbSecretVariables - public = cs.NbPublicVariables - return } // GetNbConstraints returns the number of constraints func (cs *SparseR1CS) GetNbConstraints() int { return len(cs.Constraints) } - -// GetNbWires returns the number of wires (internal) -func (cs *SparseR1CS) GetNbWires() int { - return cs.NbInternalVariables -} - -// FrSize panics -func (cs *SparseR1CS) FrSize() int { - panic("not implemented") -} - -// GetNbCoefficients return the number of unique coefficients needed in the R1CS -func (cs *SparseR1CS) GetNbCoefficients() int { - panic("not implemented") -} - -// CurveID returns ecc.UNKNOWN as this is a untyped R1CS using big.Int -func (cs *SparseR1CS) CurveID() ecc.ID { - return ecc.UNKNOWN -} - -// WriteTo panics -func (cs *SparseR1CS) WriteTo(w io.Writer) (n int64, err error) { - panic("not implemented") -} - -// ReadFrom panics -func (cs *SparseR1CS) ReadFrom(r io.Reader) (n int64, err error) { - panic("not implemented") -} - -// SetLoggerOutput replace existing logger output with provided one -// default uses os.Stdout -// if nil is provided, logs are not printed -func (cs *SparseR1CS) SetLoggerOutput(w io.Writer) { - panic("not implemented") -} - -func (cs *SparseR1CS) ToHTML(w io.Writer) error { - panic("not implemtened") -} diff --git a/internal/backend/compiled/term.go b/internal/backend/compiled/term.go index 60d420b0e2..6800544d0b 100644 --- a/internal/backend/compiled/term.go +++ b/internal/backend/compiled/term.go @@ -14,6 +14,12 @@ package compiled +import ( + "math/big" + "strconv" + "strings" +) + // Term lightweight version of a term, no pointers // first 4 bits are reserved // next 30 bits represented the coefficient idx (in r1cs.Coefficients) by which the wire is multiplied @@ -40,20 +46,27 @@ const ( const ( nbBitsVariableID = 29 nbBitsCoeffID = 30 - nbBitsFutureUse = 2 + nbBitsDelimitor = 1 + nbBitsFutureUse = 1 nbBitsVariableVisibility = 3 ) +// TermDelimitor is reserved for internal use +// the constraint solver will evaluate the sum of all terms appearing between two TermDelimitor +const TermDelimitor Term = Term(maskDelimitor) + const ( shiftVariableID = 0 shiftCoeffID = nbBitsVariableID - shiftFutureUse = shiftCoeffID + nbBitsCoeffID + shiftDelimitor = shiftCoeffID + nbBitsCoeffID + shiftFutureUse = shiftDelimitor + nbBitsDelimitor shiftVariableVisibility = shiftFutureUse + nbBitsFutureUse ) const ( maskVariableID = uint64((1 << nbBitsVariableID) - 1) maskCoeffID = uint64((1<> shiftCoeffID) } + +func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { + sbb.WriteString(coeffs[t.CoeffID()].String()) + sbb.WriteString("*") + switch t.VariableVisibility() { + case Internal: + sbb.WriteString("i") + case Public: + sbb.WriteString("p") + case Secret: + sbb.WriteString("s") + case Virtual: + sbb.WriteString("v") + case Unset: + sbb.WriteString("u") + default: + panic("not implemented") + } + sbb.WriteString(strconv.Itoa(t.VariableID())) +} diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 5b406b972c..90e3562a8d 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -76,7 +76,7 @@ func main() { entries := []bavard.Entry{ {File: filepath.Join(backendCSDir, "r1cs.go"), Templates: []string{"r1cs.go.tmpl", importCurve}}, {File: filepath.Join(backendCSDir, "r1cs_sparse.go"), Templates: []string{"r1cs.sparse.go.tmpl", importCurve}}, - {File: filepath.Join(backendCSDir, "cs.go"), Templates: []string{"cs.go.tmpl", importCurve}}, + {File: filepath.Join(backendCSDir, "solution.go"), Templates: []string{"solution.go.tmpl", importCurve}}, {File: filepath.Join(backendCSDir, "hints.go"), Templates: []string{"hints.go.tmpl", importCurve}}, } if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index fa61f08062..16909315e6 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -2,7 +2,6 @@ import ( "errors" "fmt" "io" - "os" "math/big" "strings" @@ -10,7 +9,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/backend" "github.com/consensys/gnark-crypto/ecc" "text/template" @@ -23,8 +22,6 @@ import ( type R1CS struct { compiled.R1CS Coefficients []fr.Element // R1C coefficients indexes point here - loggerOut io.Writer - mHints map[int]int // correspondance between hint wire ID and hint data struct } // NewR1CS returns a new R1CS and sets cs.Coefficient (fr.Element) from provided big.Int values @@ -32,14 +29,11 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { r := R1CS{ R1CS: cs, Coefficients: make([]fr.Element, len(coefficients)), - loggerOut: os.Stdout, } for i := 0; i < len(coefficients); i++ { r.Coefficients[i].SetBigInt(&coefficients[i]) } - r.initHints() - return &r } @@ -49,10 +43,10 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { // a, b, c vectors: ab-c = hz // witness = [publicWires | secretWires] (without the ONE_WIRE !) // returns [publicWires | secretWires | internalWires ] -func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Function) ([]fr.Element, error) { +func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverOption) ([]fr.Element, error) { nbWires := cs.NbPublicVariables + cs.NbSecretVariables + cs.NbInternalVariables - solution, err := newSolution(nbWires, hintFunctions, cs.Coefficients) + solution, err := newSolution(nbWires, opt.HintFunctions, cs.Coefficients) if err != nil { return make([]fr.Element, nbWires), err } @@ -64,8 +58,8 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // compute the wires and the a, b, c polynomials - if len(a) != int(cs.NbConstraints) || len(b) != int(cs.NbConstraints) || len(c) != int(cs.NbConstraints){ - return solution.values, errors.New("invalid input size: len(a, b, c) == cs.NbConstraints") + if len(a) != len(cs.Constraints) || len(b) != len(cs.Constraints) || len(c) != len(cs.Constraints){ + return solution.values, errors.New("invalid input size: len(a, b, c) == len(Constraints)") } @@ -83,17 +77,12 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) - defer solution.printLogs(cs.loggerOut, cs.Logs) + defer solution.printLogs(opt.LoggerOut, cs.Logs) // check if there is an inconsistant constraint var check fr.Element - // TODO @gbotrel clean this - // this variable is used to navigate in the debugInfoComputation slice. - // It is incremented by one each time a division happens for solving a constraint. - var debugInfoComputationOffset uint - // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -102,11 +91,10 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - offset, err := cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution); err != nil { + // TODO should return debug info, if any. return solution.values, err } - debugInfoComputationOffset += offset // compute values for the R1C (ie value * coeff) a[i], b[i], c[i] = cs.instantiateR1C(cs.Constraints[i], &solution) @@ -114,9 +102,11 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // ensure a[i] * b[i] == c[i] check.Mul(&a[i], &b[i]) if !check.Equal(&c[i]) { - debugInfo := cs.DebugInfoComputation[debugInfoComputationOffset] - debugInfoStr := solution.logValue(debugInfo) - return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", ErrUnsatisfiedConstraint, debugInfoStr) + } + return solution.values, ErrUnsatisfiedConstraint } } @@ -131,27 +121,16 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, hintFunctions []hint.Functi // IsSolved returns nil if given witness solves the R1CS and error otherwise // this method wraps cs.Solve() and allocates cs.Solve() inputs -func (cs *R1CS) IsSolved(witness []fr.Element, hintFunctions []hint.Function) error { - a := make([]fr.Element, cs.NbConstraints) - b := make([]fr.Element, cs.NbConstraints) - c := make([]fr.Element, cs.NbConstraints) - _, err := cs.Solve(witness, a, b, c, hintFunctions) +func (cs *R1CS) IsSolved(witness []fr.Element, opt backend.ProverOption) error { + a := make([]fr.Element, len(cs.Constraints)) + b := make([]fr.Element, len(cs.Constraints)) + c := make([]fr.Element, len(cs.Constraints)) + _, err := cs.Solve(witness, a, b, c, opt) return err } -func (cs *R1CS) initHints() { - // we may do that sooner to save time in the solver, but we want the serialized data structures to be - // deterministic, hence avoid maps in there. - cs.mHints = make(map[int]int, len(cs.Hints)) - for i:=0; i