diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index b62ac94fb6..3994514c40 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -193,6 +193,8 @@ func (system *scs) FromBinary(b ...frontend.Variable) frontend.Variable { // Xor returns a ^ b // a and b must be 0 or 1 func (system *scs) Xor(a, b frontend.Variable) frontend.Variable { + system.AssertIsBoolean(a) + system.AssertIsBoolean(b) _a, aConstant := system.ConstantValue(a) _b, bConstant := system.ConstantValue(b) @@ -200,6 +202,7 @@ func (system *scs) Xor(a, b frontend.Variable) frontend.Variable { _a.Xor(_a, _b) return _a } + res := system.newInternalVariable() if aConstant { a, b = b, a @@ -224,6 +227,10 @@ func (system *scs) Xor(a, b frontend.Variable) frontend.Variable { // Or returns a | b // a and b must be 0 or 1 func (system *scs) Or(a, b frontend.Variable) frontend.Variable { + + system.AssertIsBoolean(a) + system.AssertIsBoolean(b) + _a, aConstant := system.ConstantValue(a) _b, bConstant := system.ConstantValue(b) @@ -241,11 +248,6 @@ func (system *scs) Or(a, b frontend.Variable) frontend.Variable { l := a.(compiled.Term) r := l - if !(_b.IsUint64() && (_b.Uint64() <= 1)) { - panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) - } - system.AssertIsBoolean(a) - one := big.NewInt(1) _b.Sub(_b, one) idl := system.st.CoeffID(_b) @@ -254,8 +256,6 @@ func (system *scs) Or(a, b frontend.Variable) frontend.Variable { } l := a.(compiled.Term) r := b.(compiled.Term) - system.AssertIsBoolean(l) - system.AssertIsBoolean(r) system.addPlonkConstraint(l, r, res, compiled.CoeffIdMinusOne, compiled.CoeffIdMinusOne, compiled.CoeffIdOne, compiled.CoeffIdOne, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 711050605a..40ef429dae 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index f7ad4cf863..edcbd6cfff 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 96eba3afa4..bc8c04500b 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index e49202718b..1c92e9a87e 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index 6744060867..60c6e54729 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index ba35ad197c..74697271c5 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/backend/circuits/range.go b/internal/backend/circuits/range.go index e73d491a00..48dfbbe75c 100644 --- a/internal/backend/circuits/range.go +++ b/internal/backend/circuits/range.go @@ -5,7 +5,7 @@ import ( "github.com/consensys/gnark/frontend" ) -const bound = 161 +const bound = 44 type rangeCheckConstantCircuit struct { X frontend.Variable @@ -24,8 +24,8 @@ func (circuit *rangeCheckConstantCircuit) Define(api frontend.API) error { func rangeCheckConstant() { var circuit, good, bad rangeCheckConstantCircuit - good.X = (10) - good.Y = (4) + good.X = (4) + good.Y = (2) bad.X = (11) bad.Y = (4) @@ -52,8 +52,8 @@ func rangeCheck() { var circuit, good, bad rangeCheckCircuit - good.X = (10) - good.Y = (4) + good.X = (4) + good.Y = (2) good.Bound = (bound) bad.X = (11) diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 04b815cff6..1fc8ffd83c 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -48,7 +48,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/internal/tinyfield/cs/r1cs_sparse.go b/internal/tinyfield/cs/r1cs_sparse.go index 7d2d2b03c9..736f70b60c 100644 --- a/internal/tinyfield/cs/r1cs_sparse.go +++ b/internal/tinyfield/cs/r1cs_sparse.go @@ -66,7 +66,7 @@ func NewSparseR1CS(ccs compiled.SparseR1CS, coefficients []big.Int) *SparseR1CS // witness: contains the input variables // it returns the full slice of wires func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]fr.Element, error) { - log := logger.Logger().With().Str("curve", cs.CurveID().String()).Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() + log := logger.Logger().With().Int("nbConstraints", len(cs.Constraints)).Str("backend", "plonk").Logger() // set the slices holding the solution.values and monitoring which variables have been solved nbVariables := cs.NbInternalVariables + cs.NbSecretVariables + cs.NbPublicVariables diff --git a/test/solver_test.go b/test/solver_test.go new file mode 100644 index 0000000000..afd7996832 --- /dev/null +++ b/test/solver_test.go @@ -0,0 +1,250 @@ +package test + +import ( + "errors" + "fmt" + "math/big" + "reflect" + "strconv" + "strings" + "testing" + + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/backend/circuits" + "github.com/consensys/gnark/internal/tinyfield" + "github.com/consensys/gnark/internal/tinyfield/cs" + "github.com/consensys/gnark/internal/utils" +) + +// ignore witness size larger than this bound +const permutterBound = 3 + +func TestSolverConsistency(t *testing.T) { + if testing.Short() { + t.Skip("skipping R1CS solver test with testing.Short() flag set") + return + } + + // idea is test circuits, we are going to test all possible values of the witness. + // (hence the choice of a small modulus for the field size) + // + // we generate witnesses and compare with the output of big.Int test engine against + // R1CS and SparseR1CS solvers + + for name := range circuits.Circuits { + t.Run(name, func(t *testing.T) { + tc := circuits.Circuits[name] + t.Parallel() + err := consistentSolver(tc.Circuit, tc.HintFunctions) + if err != nil { + t.Fatal(err) + } + }) + } +} + +type permutter struct { + circuit frontend.Circuit + r1cs *cs.R1CS + scs *cs.SparseR1CS + witness []tinyfield.Element + hints []hint.Function + + // used to avoid allocations in R1CS solver + a, b, c []tinyfield.Element +} + +// note that circuit will be mutated and this is not thread safe +func (p *permutter) permuteAndTest(index int) error { + + for i := 0; i < len(tinyfieldElements); i++ { + p.witness[index].SetUint64(tinyfieldElements[i]) + if index == len(p.witness)-1 { + // we have a unique permutation + + // solve the cs using R1CS solver + errR1CS := p.solveR1CS() + errSCS := p.solveSCS() + + // solve the cs using test engine + // first copy the witness in the circuit + copyWitnessFromVector(p.circuit, p.witness) + errEngine1 := isSolvedEngine(p.circuit, tinyfield.Modulus()) + + copyWitnessFromVector(p.circuit, p.witness) + errEngine2 := isSolvedEngine(p.circuit, tinyfield.Modulus(), SetAllVariablesAsConstants()) + + if (errR1CS == nil) != (errEngine1 == nil) || + (errSCS == nil) != (errEngine1 == nil) || + (errEngine1 == nil) != (errEngine2 == nil) { + return fmt.Errorf("errSCS :%s\nerrR1CS :%s\nerrEngine(const=false): %s\nerrEngine(const=true): %s\nwitness: %s", + formatError(errSCS), + formatError(errR1CS), + formatError(errEngine1), + formatError(errEngine2), + formatWitness(p.witness)) + } + } else { + // recurse + if err := p.permuteAndTest(index + 1); err != nil { + return err + } + } + } + return nil +} + +func formatError(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func formatWitness(witness []tinyfield.Element) string { + var sbb strings.Builder + sbb.WriteByte('[') + + for i := 0; i < len(witness); i++ { + sbb.WriteString(strconv.Itoa(int(witness[i].Uint64()))) + if i != len(witness)-1 { + sbb.WriteString(", ") + } + } + + sbb.WriteByte(']') + + return sbb.String() +} + +func (p *permutter) solveSCS() error { + opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) + if err != nil { + return err + } + + _, err = p.scs.Solve(p.witness, opt) + return err +} + +func (p *permutter) solveR1CS() error { + opt, err := backend.NewProverConfig(backend.WithHints(p.hints...)) + if err != nil { + return err + } + + for i := 0; i < len(p.r1cs.Constraints); i++ { + p.a[i].SetZero() + p.b[i].SetZero() + p.c[i].SetZero() + } + _, err = p.r1cs.Solve(p.witness, p.a, p.b, p.c, opt) + return err +} + +// isSolvedEngine behaves like test.IsSolved except it doesn't clone the circuit +func isSolvedEngine(c frontend.Circuit, field *big.Int, opts ...TestEngineOption) (err error) { + e := &engine{ + curveID: utils.FieldToCurve(field), + q: new(big.Int).Set(field), + apiWrapper: func(a frontend.API) frontend.API { return a }, + constVars: false, + } + for _, opt := range opts { + if err := opt(e); err != nil { + return fmt.Errorf("apply option: %w", err) + } + } + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) + } + }() + + api := e.apiWrapper(e) + err = c.Define(api) + + return +} + +// fill the "to" frontend.Circuit with values from the provided vector +// values are assumed to be ordered [public | secret] +func copyWitnessFromVector(to frontend.Circuit, from []tinyfield.Element) { + i := 0 + schema.Parse(to, tVariable, func(visibility schema.Visibility, name string, tInput reflect.Value) error { + if visibility == schema.Public { + tInput.Set(reflect.ValueOf((from[i]))) + i++ + } + return nil + }) + + schema.Parse(to, tVariable, func(visibility schema.Visibility, name string, tInput reflect.Value) error { + if visibility == schema.Secret { + tInput.Set(reflect.ValueOf((from[i]))) + i++ + } + return nil + }) +} + +// ConsistentSolver solves given circuit with all possible witness combinations using internal/tinyfield +// +// Since the goal of this method is to flag potential solver issues, it is not exposed as an API for now +func consistentSolver(circuit frontend.Circuit, hintFunctions []hint.Function) error { + + p := permutter{ + circuit: circuit, + hints: hintFunctions, + } + + // compile R1CS + ccs, err := frontend.Compile(tinyfield.Modulus(), r1cs.NewBuilder, circuit) + if err != nil { + return err + } + + p.r1cs = ccs.(*cs.R1CS) + + // witness len + n := p.r1cs.NbPublicVariables - 1 + p.r1cs.NbSecretVariables + if n > permutterBound { + return nil + } + + p.a = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) + p.b = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) + p.c = make([]tinyfield.Element, p.r1cs.GetNbConstraints()) + p.witness = make([]tinyfield.Element, n) + + // compile SparseR1CS + ccs, err = frontend.Compile(tinyfield.Modulus(), scs.NewBuilder, circuit) + if err != nil { + return err + } + + p.scs = ccs.(*cs.SparseR1CS) + if (p.scs.NbPublicVariables + p.scs.NbSecretVariables) != n { + return errors.New("mismatch of witness size for same circuit") + } + + return p.permuteAndTest(0) +} + +// [0, 1, ..., q - 1], with q == tinyfield.Modulus() +var tinyfieldElements []uint64 + +func init() { + n := tinyfield.Modulus().Uint64() + tinyfieldElements = make([]uint64, n) + for i := uint64(0); i < n; i++ { + tinyfieldElements[i] = i + } +}