From 4b9e258f40af1944e89f8a005ec885e88848e9d8 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 18 Feb 2022 13:29:32 -0600 Subject: [PATCH 01/20] feat: hint.NbOuputs should not be used at solve time, only at compile time --- internal/backend/bls12-377/cs/solution.go | 2 +- internal/backend/bls12-381/cs/solution.go | 2 +- internal/backend/bls24-315/cs/solution.go | 2 +- internal/backend/bn254/cs/solution.go | 2 +- internal/backend/bw6-633/cs/solution.go | 2 +- internal/backend/bw6-761/cs/solution.go | 2 +- .../generator/backend/template/representations/solution.go.tmpl | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 6911e72ab7..b8b7dd02bb 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 9d630c8153..cb0964161d 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index e215ac343a..179b3490ec 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 46cb2eb6af..2c526afcd1 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index 4b611e7f07..f4e08474ef 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index fb1e1a19bd..162ab02f8f 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index dba02f7cb6..a45ae12b63 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -129,7 +129,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) From f322f8cab0a0e8a2886e0d2bbbcae91778c11430 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 18 Feb 2022 13:45:56 -0600 Subject: [PATCH 02/20] feat: make nboutputs of a hint explicit at compile time --- backend/hint/hint.go | 13 ++++++++----- frontend/api.go | 4 +++- frontend/cs/plonk/api.go | 10 +++++----- frontend/cs/r1cs/api.go | 10 +++++----- internal/backend/circuits/hint.go | 6 +++--- std/algebra/sw_bls12377/g1.go | 2 +- std/algebra/sw_bls24315/g1.go | 2 +- test/engine.go | 6 +++--- test/engine_test.go | 8 ++++---- 9 files changed, 33 insertions(+), 28 deletions(-) diff --git a/backend/hint/hint.go b/backend/hint/hint.go index a68f768b8d..e3f7042599 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -96,10 +96,13 @@ type Function interface { // obtained from cache). A returned non-nil error will be propagated. Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error - // NbOutputs returns the total number of outputs by the function when + // NbOutputs returns the MAX total number of outputs by the function when + // TODO @gbotrel @ivokub --> this should be used at compile time only. + // at solving time, what makes law is the number of wires associated with the hints + // assuming they were correctly allocated in the compile phase. // invoked on the curveID with nInputs number of inputs. The number of // outputs must be at least one and the framework errors otherwise. - NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int) + // NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int) // String returns a human-readable description of the function used in logs // and debug messages. @@ -157,9 +160,9 @@ func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res [] return h.fn(curveID, inputs, res) } -func (h *staticArgumentsFunction) NbOutputs(_ ecc.ID, _ int) int { - return h.nOut -} +// func (h *staticArgumentsFunction) NbOutputs(_ ecc.ID, _ int) int { +// return h.nOut +// } func (h *staticArgumentsFunction) UUID() ID { return UUID(h.fn, uint64(h.nIn), uint64(h.nOut)) diff --git a/frontend/api.go b/frontend/api.go index 3e7fde87c0..95c4298bc8 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -123,7 +123,9 @@ type API interface { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. - NewHint(f hint.Function, inputs ...Variable) ([]Variable, error) + // + // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to // measure constraints, variables and coefficients creations through AddCounter diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go index 117dc7744a..8bbf8e3034 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/plonk/api.go @@ -208,7 +208,7 @@ func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []f var c big.Int c.SetUint64(1) for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, a, i) + res, err := system.NewHint(hint.IthBit, 1, a, i) if err != nil { panic(err) } @@ -413,7 +413,7 @@ func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { // a * m = 0 // constrain m to be 0 if a != 0 // _ = inverse(m + a) // constrain m to be 1 if a == 0 a := i1.(compiled.Term) - res, err := system.NewHint(hint.IsZero, a) + res, err := system.NewHint(hint.IsZero, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) @@ -573,9 +573,9 @@ func (system *sparseR1CS) AddCounter(from, to frontend.Tag) { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (system *sparseR1CS) NewHint(f hint.Function, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (system *sparseR1CS) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - if f.NbOutputs(system.Curve(), len(inputs)) <= 0 { + if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") } @@ -592,7 +592,7 @@ func (system *sparseR1CS) NewHint(f hint.Function, inputs ...frontend.Variable) } // prepare wires - varIDs := make([]int, f.NbOutputs(system.Curve(), len(inputs))) + varIDs := make([]int, nbOutputs) res := make([]frontend.Variable, len(varIDs)) for i := range varIDs { r := system.newInternalVariable() diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 9fff35cd11..79d1f2fd7e 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -306,7 +306,7 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro var c big.Int c.SetUint64(1) for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, a, i) + res, err := system.NewHint(hint.IthBit, 1, a, i) if err != nil { panic(err) } @@ -523,7 +523,7 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { // _ = inverse(m + a) // constrain m to be 1 if a == 0 // m is computed by the solver such that m = 1 - a^(modulus - 1) - res, err := system.NewHint(hint.IsZero, a) + res, err := system.NewHint(hint.IsZero, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) @@ -709,9 +709,9 @@ func (system *r1CS) AddCounter(from, to frontend.Tag) { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (system *r1CS) NewHint(f hint.Function, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (system *r1CS) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - if f.NbOutputs(system.Curve(), len(inputs)) <= 0 { + if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") } hintInputs := make([]interface{}, len(inputs)) @@ -732,7 +732,7 @@ func (system *r1CS) NewHint(f hint.Function, inputs ...frontend.Variable) ([]fro } // prepare wires - varIDs := make([]int, f.NbOutputs(system.Curve(), len(inputs))) + varIDs := make([]int, nbOutputs) res := make([]frontend.Variable, len(varIDs)) for i := range varIDs { r := system.newInternalVariable() diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index d6eec90424..caecaf6d51 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -14,7 +14,7 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(mulBy7, circuit.A) + res, err := api.NewHint(mulBy7, 1, circuit.A) if err != nil { return fmt.Errorf("mulBy7: %w", err) } @@ -23,7 +23,7 @@ func (circuit *hintCircuit) Define(api frontend.API) error { api.AssertIsEqual(a7, _a7) api.AssertIsEqual(a7, circuit.B) - res, err = api.NewHint(make3) + res, err = api.NewHint(make3, 1) if err != nil { return fmt.Errorf("make3: %w", err) } @@ -39,7 +39,7 @@ type vectorDoubleCircuit struct { } func (c *vectorDoubleCircuit) Define(api frontend.API) error { - res, err := api.NewHint(dvHint, c.A...) + res, err := api.NewHint(dvHint, len(c.B), c.A...) if err != nil { return fmt.Errorf("double newhint: %w", err) } diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index 7494b8d594..fc8c945ed4 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -254,7 +254,7 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS12377, s) + sd, err := api.NewHint(scalarDecompositionHintBLS12377, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index 4afde9a1cd..e613a6b08d 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -254,7 +254,7 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS24315, s) + sd, err := api.NewHint(scalarDecompositionHintBLS24315, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) diff --git a/test/engine.go b/test/engine.go index 9469dedd06..f623ad08ce 100644 --- a/test/engine.go +++ b/test/engine.go @@ -325,9 +325,9 @@ func (e *engine) Println(a ...frontend.Variable) { fmt.Println(sbb.String()) } -func (e *engine) NewHint(f hint.Function, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - if f.NbOutputs(e.Curve(), len(inputs)) <= 0 { + if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") } @@ -337,7 +337,7 @@ func (e *engine) NewHint(f hint.Function, inputs ...frontend.Variable) ([]fronte v := e.toBigInt(inputs[i]) in[i] = &v } - res := make([]*big.Int, f.NbOutputs(e.Curve(), len(inputs))) + res := make([]*big.Int, nbOutputs) for i := range res { res[i] = new(big.Int) } diff --git a/test/engine_test.go b/test/engine_test.go index a40119daf3..851e5ef972 100644 --- a/test/engine_test.go +++ b/test/engine_test.go @@ -15,22 +15,22 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(hint.IthBit, circuit.A, 3) + res, err := api.NewHint(hint.IthBit, 1, circuit.A, 3) if err != nil { return fmt.Errorf("IthBit circuitA 3: %w", err) } a3b := res[0] - res, err = api.NewHint(hint.IthBit, circuit.A, 25) + res, err = api.NewHint(hint.IthBit, 1, circuit.A, 25) if err != nil { return fmt.Errorf("IthBit circuitA 25: %w", err) } a25b := res[0] - res, err = api.NewHint(hint.IsZero, circuit.A) + res, err = api.NewHint(hint.IsZero, 1, circuit.A) if err != nil { return fmt.Errorf("IsZero CircuitA: %w", err) } aisZero := res[0] - res, err = api.NewHint(hint.IsZero, circuit.B) + res, err = api.NewHint(hint.IsZero, 1, circuit.B) if err != nil { return fmt.Errorf("IsZero, CircuitB") } From 5c716b75acc186681d08189cd8e595a151238a65 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 18 Feb 2022 13:59:15 -0600 Subject: [PATCH 03/20] feat: added NBits hint --- backend/hint/builtin.go | 17 +++++++++++++++-- backend/hint/hint.go | 19 +++++++------------ frontend/cs/plonk/api.go | 18 +++++++++--------- frontend/cs/r1cs/api.go | 18 +++++++++--------- internal/backend/circuits/hint.go | 4 ++-- std/algebra/sw_bls12377/g1.go | 2 +- std/algebra/sw_bls24315/g1.go | 2 +- 7 files changed, 44 insertions(+), 36 deletions(-) diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go index 9fcf91771f..aaae5568e8 100644 --- a/backend/hint/builtin.go +++ b/backend/hint/builtin.go @@ -11,10 +11,12 @@ var initBuiltinOnce sync.Once func init() { initBuiltinOnce.Do(func() { - IsZero = NewStaticHint(builtinIsZero, 1, 1) + IsZero = NewStaticHint(builtinIsZero, 1) Register(IsZero) - IthBit = NewStaticHint(builtinIthBit, 2, 1) + IthBit = NewStaticHint(builtinIthBit, 2) Register(IthBit) + NBits = NewStaticHint(builtinNBits, 1) + Register(NBits) }) } @@ -30,6 +32,9 @@ var ( // integer inputs i and n, takes the little-endian bit representation of n and // returns its i-th bit. IthBit Function + + // NBits returns the n first bits of the input. Expects one argument: n. + NBits Function ) func builtinIsZero(curveID ecc.ID, inputs []*big.Int, results []*big.Int) error { @@ -63,3 +68,11 @@ func builtinIthBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64())))) return nil } + +func builtinNBits(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + for i := 0; i < len(results); i++ { + results[i].SetUint64(uint64(n.Bit(i))) + } + return nil +} diff --git a/backend/hint/hint.go b/backend/hint/hint.go index e3f7042599..95a38c6c17 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -133,20 +133,18 @@ func UUID(fn StaticFunction, ctx ...uint64) ID { // staticArgumentsFunction defines a function where the number of inputs and // outputs is constant. type staticArgumentsFunction struct { - fn StaticFunction - nIn int - nOut int + fn StaticFunction + nIn int } // NewStaticHint returns an Function where the number of inputs and outputs is // constant. UUID is computed by combining fn, nIn and nOut and thus it is legal // to defined multiple AnnotatedFunctions on the same fn with different nIn and // nOut. -func NewStaticHint(fn StaticFunction, nIn, nOut int) Function { +func NewStaticHint(fn StaticFunction, nIn int) Function { return &staticArgumentsFunction{ - fn: fn, - nIn: nIn, - nOut: nOut, + fn: fn, + nIn: nIn, } } @@ -154,9 +152,6 @@ func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res [] if len(inputs) != h.nIn { return fmt.Errorf("input has %d elements, expected %d", len(inputs), h.nIn) } - if len(res) != h.nOut { - return fmt.Errorf("result has %d elements, expected %d", len(res), h.nOut) - } return h.fn(curveID, inputs, res) } @@ -165,11 +160,11 @@ func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res [] // } func (h *staticArgumentsFunction) UUID() ID { - return UUID(h.fn, uint64(h.nIn), uint64(h.nOut)) + return UUID(h.fn, uint64(h.nIn)) } func (h *staticArgumentsFunction) String() string { fnptr := reflect.ValueOf(h.fn).Pointer() name := runtime.FuncForPC(fnptr).Name() - return fmt.Sprintf("%s([%d]*big.Int, [%d]*big.Int) at (%x)", name, h.nIn, h.nOut, fnptr) + return fmt.Sprintf("%s([%d]*big.Int, [?]*big.Int) at (%x)", name, h.nIn, fnptr) } diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go index 8bbf8e3034..385d483efe 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/plonk/api.go @@ -203,20 +203,20 @@ func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Va func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { // allocate the resulting frontend.Variables and bit-constraint them - b := make([]frontend.Variable, nbBits) sb := make([]frontend.Variable, nbBits) var c big.Int c.SetUint64(1) + + bits, err := system.NewHint(hint.NBits, nbBits, a) + if err != nil { + panic(err) + } + for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, 1, a, i) - if err != nil { - panic(err) - } - b[i] = res[0] - sb[i] = system.Mul(b[i], c) + sb[i] = system.Mul(bits[i], c) c.Lsh(&c, 1) if !unsafe { - system.AssertIsBoolean(b[i]) + system.AssertIsBoolean(bits[i]) } } @@ -233,7 +233,7 @@ func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []f system.AssertIsEqual(Σbi, a) // record the constraint Σ (2**i * b[i]) == a - return b + return bits } diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 79d1f2fd7e..e6f06ce614 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -301,20 +301,20 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro a.AssertIsSet() // allocate the resulting frontend.Variables and bit-constraint them - b := make([]frontend.Variable, nbBits) sb := make([]frontend.Variable, nbBits) var c big.Int c.SetUint64(1) + + bits, err := system.NewHint(hint.NBits, nbBits, a) + if err != nil { + panic(err) + } + for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, 1, a, i) - if err != nil { - panic(err) - } - b[i] = res[0] - sb[i] = system.Mul(b[i], c) + sb[i] = system.Mul(bits[i], c) c.Lsh(&c, 1) if !unsafe { - system.AssertIsBoolean(b[i]) + system.AssertIsBoolean(bits[i]) } } @@ -330,7 +330,7 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro system.AssertIsEqual(Σbi, a) // record the constraint Σ (2**i * b[i]) == a - return b + return bits } diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index caecaf6d51..b3bc1c5070 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -101,12 +101,12 @@ func init() { var mulBy7 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].Mul(inputs[0], big.NewInt(7)).Mod(result[0], curveID.Info().Fr.Modulus()) return nil -}, 1, 1) +}, 1) var make3 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].SetUint64(3) return nil -}, 0, 1) +}, 0) var dvHint = &doubleVector{} diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index fc8c945ed4..496a6c7d59 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS12377 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1, 3) +}, 1) func init() { hint.Register(scalarDecompositionHintBLS12377) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index e613a6b08d..dc079df28f 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS24315 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1, 3) +}, 1) func init() { hint.Register(scalarDecompositionHintBLS24315) From 0d1da2493bdf3aa39be51e554d03a335563217cf Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 24 Feb 2022 16:37:17 -0600 Subject: [PATCH 04/20] refactor: remove nb inputs from hint declaration --- backend/hint/builtin.go | 6 ++-- backend/hint/hint.go | 55 ++++++++----------------------- internal/backend/circuits/hint.go | 4 +-- std/algebra/sw_bls12377/g1.go | 2 +- std/algebra/sw_bls24315/g1.go | 2 +- 5 files changed, 20 insertions(+), 49 deletions(-) diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go index aaae5568e8..451301c000 100644 --- a/backend/hint/builtin.go +++ b/backend/hint/builtin.go @@ -11,11 +11,11 @@ var initBuiltinOnce sync.Once func init() { initBuiltinOnce.Do(func() { - IsZero = NewStaticHint(builtinIsZero, 1) + IsZero = NewStaticHint(builtinIsZero) Register(IsZero) - IthBit = NewStaticHint(builtinIthBit, 2) + IthBit = NewStaticHint(builtinIthBit) Register(IthBit) - NBits = NewStaticHint(builtinNBits, 1) + NBits = NewStaticHint(builtinNBits) Register(NBits) }) } diff --git a/backend/hint/hint.go b/backend/hint/hint.go index 95a38c6c17..7f2ab90115 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -81,7 +81,7 @@ type ID uint32 // StaticFunction is a function which takes a constant number of inputs and // returns a constant number of outputs. Use NewStaticHint() to construct an // instance compatible with Function interface. -type StaticFunction func(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error +type StaticFunction func(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // Function defines an annotated hint function. To initialize a hint function // with static number of inputs and outputs, use NewStaticHint(). @@ -94,21 +94,17 @@ type Function interface { // The length of res is NbOutputs() and every element is // already initialized (but not necessarily to zero as the elements may be // obtained from cache). A returned non-nil error will be propagated. - Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error - - // NbOutputs returns the MAX total number of outputs by the function when - // TODO @gbotrel @ivokub --> this should be used at compile time only. - // at solving time, what makes law is the number of wires associated with the hints - // assuming they were correctly allocated in the compile phase. - // invoked on the curveID with nInputs number of inputs. The number of - // outputs must be at least one and the framework errors otherwise. - // NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int) + Call(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // String returns a human-readable description of the function used in logs // and debug messages. String() string } +func NewStaticHint(fn StaticFunction) Function { + return fn +} + // UUID is a reference function for computing the hint ID based on a function // and additional context values ctx. A change in any of the inputs modifies the // returned value and thus this function can be used to compute the hint ID for @@ -130,41 +126,16 @@ func UUID(fn StaticFunction, ctx ...uint64) ID { return ID(hf.Sum32()) } -// staticArgumentsFunction defines a function where the number of inputs and -// outputs is constant. -type staticArgumentsFunction struct { - fn StaticFunction - nIn int -} - -// NewStaticHint returns an Function where the number of inputs and outputs is -// constant. UUID is computed by combining fn, nIn and nOut and thus it is legal -// to defined multiple AnnotatedFunctions on the same fn with different nIn and -// nOut. -func NewStaticHint(fn StaticFunction, nIn int) Function { - return &staticArgumentsFunction{ - fn: fn, - nIn: nIn, - } +func (h StaticFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { + return h(curveID, inputs, res) } -func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { - if len(inputs) != h.nIn { - return fmt.Errorf("input has %d elements, expected %d", len(inputs), h.nIn) - } - return h.fn(curveID, inputs, res) -} - -// func (h *staticArgumentsFunction) NbOutputs(_ ecc.ID, _ int) int { -// return h.nOut -// } - -func (h *staticArgumentsFunction) UUID() ID { - return UUID(h.fn, uint64(h.nIn)) +func (h StaticFunction) UUID() ID { + return UUID(h) } -func (h *staticArgumentsFunction) String() string { - fnptr := reflect.ValueOf(h.fn).Pointer() +func (h StaticFunction) String() string { + fnptr := reflect.ValueOf(h).Pointer() name := runtime.FuncForPC(fnptr).Name() - return fmt.Sprintf("%s([%d]*big.Int, [?]*big.Int) at (%x)", name, h.nIn, fnptr) + return fmt.Sprintf("%s([?]*big.Int, [?]*big.Int) at (%x)", name, fnptr) } diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index b3bc1c5070..3e64e8cf97 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -101,12 +101,12 @@ func init() { var mulBy7 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].Mul(inputs[0], big.NewInt(7)).Mod(result[0], curveID.Info().Fr.Modulus()) return nil -}, 1) +}) var make3 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].SetUint64(3) return nil -}, 0) +}) var dvHint = &doubleVector{} diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index 496a6c7d59..208f299d13 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS12377 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1) +}) func init() { hint.Register(scalarDecompositionHintBLS12377) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index dc079df28f..e18f11550b 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS24315 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1) +}) func init() { hint.Register(scalarDecompositionHintBLS24315) From b61f2418294aa97214e0ab7e963815106d355386 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Thu, 24 Feb 2022 16:40:33 -0600 Subject: [PATCH 05/20] docs: clean up hint interface comment --- backend/hint/hint.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/backend/hint/hint.go b/backend/hint/hint.go index 7f2ab90115..a21bff6134 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -91,9 +91,7 @@ type Function interface { UUID() ID // Call is invoked by the framework to obtain the result from inputs. - // The length of res is NbOutputs() and every element is - // already initialized (but not necessarily to zero as the elements may be - // obtained from cache). A returned non-nil error will be propagated. + // Elements in outputs are not guaranteed to be initialized to 0 Call(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // String returns a human-readable description of the function used in logs From a9248bfde50cb3b20212715947820e2fe5b30361 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 25 Feb 2022 10:17:16 -0600 Subject: [PATCH 06/20] refactor: moved internal/compiled to frontend/compiled --- backend/hint/builtin.go | 2 + frontend/api.go | 1 + frontend/ccs.go | 2 +- frontend/compiled/cs.go | 143 ++++++++++++++++++ .../cs.go => frontend/compiled/hint.go | 79 ---------- .../backend => frontend}/compiled/log.go | 0 .../r1c.go => frontend/compiled/r1cs.go | 11 ++ .../compiled/r1cs_sparse.go | 11 ++ .../backend => frontend}/compiled/term.go | 49 +++++- .../backend => frontend}/compiled/variable.go | 40 ----- frontend/cs/builder.go | 66 ++++++++ frontend/cs/cs.go | 135 ----------------- frontend/cs/plonk/api.go | 16 +- frontend/cs/plonk/assertions.go | 10 +- frontend/cs/plonk/conversion.go | 18 +-- frontend/cs/plonk/sparse_r1cs.go | 45 +++--- frontend/cs/r1cs/api.go | 10 +- frontend/cs/r1cs/assertions.go | 2 +- frontend/cs/r1cs/conversion.go | 18 +-- frontend/cs/r1cs/r1cs.go | 50 +++--- frontend/cs/r1cs/r1cs_test.go | 2 +- internal/backend/bls12-377/cs/r1cs.go | 2 +- internal/backend/bls12-377/cs/r1cs_sparse.go | 2 +- internal/backend/bls12-377/cs/solution.go | 2 +- internal/backend/bls12-377/groth16/setup.go | 2 +- internal/backend/bls12-381/cs/r1cs.go | 2 +- internal/backend/bls12-381/cs/r1cs_sparse.go | 2 +- internal/backend/bls12-381/cs/solution.go | 2 +- internal/backend/bls12-381/groth16/setup.go | 2 +- internal/backend/bls24-315/cs/r1cs.go | 2 +- internal/backend/bls24-315/cs/r1cs_sparse.go | 2 +- internal/backend/bls24-315/cs/solution.go | 2 +- internal/backend/bls24-315/groth16/setup.go | 2 +- internal/backend/bn254/cs/r1cs.go | 2 +- internal/backend/bn254/cs/r1cs_sparse.go | 2 +- internal/backend/bn254/cs/solution.go | 2 +- internal/backend/bn254/groth16/setup.go | 2 +- internal/backend/bw6-633/cs/r1cs.go | 2 +- internal/backend/bw6-633/cs/r1cs_sparse.go | 2 +- internal/backend/bw6-633/cs/solution.go | 2 +- internal/backend/bw6-633/groth16/setup.go | 2 +- internal/backend/bw6-761/cs/r1cs.go | 2 +- internal/backend/bw6-761/cs/r1cs_sparse.go | 2 +- internal/backend/bw6-761/cs/solution.go | 2 +- internal/backend/bw6-761/groth16/setup.go | 2 +- internal/backend/compiled/html.go | 137 ----------------- internal/backend/compiled/r1cs.go | 26 ---- internal/backend/compiled/r1cs_sparse.go | 26 ---- .../template/representations/r1cs.go.tmpl | 2 +- .../representations/r1cs.sparse.go.tmpl | 2 +- .../template/representations/solution.go.tmpl | 2 +- .../zkpschemes/groth16/groth16.setup.go.tmpl | 2 +- internal/utils/circuit.go | 66 -------- test/assert.go | 5 +- test/engine.go | 58 ++++++- 55 files changed, 443 insertions(+), 641 deletions(-) create mode 100644 frontend/compiled/cs.go rename internal/backend/compiled/cs.go => frontend/compiled/hint.go (60%) rename {internal/backend => frontend}/compiled/log.go (100%) rename internal/backend/compiled/r1c.go => frontend/compiled/r1cs.go (81%) rename internal/backend/compiled/r1c_sparse.go => frontend/compiled/r1cs_sparse.go (84%) rename {internal/backend => frontend}/compiled/term.go (81%) rename {internal/backend => frontend}/compiled/variable.go (67%) create mode 100644 frontend/cs/builder.go delete mode 100644 frontend/cs/cs.go delete mode 100644 internal/backend/compiled/html.go delete mode 100644 internal/backend/compiled/r1cs.go delete mode 100644 internal/backend/compiled/r1cs_sparse.go diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go index 451301c000..7840c74a2d 100644 --- a/backend/hint/builtin.go +++ b/backend/hint/builtin.go @@ -20,6 +20,8 @@ func init() { }) } +// TODO FIXME these may be redefined easily by an external package + // The package provides the following built-in hint functions. All built-in hint // functions are registered in the registry. var ( diff --git a/frontend/api.go b/frontend/api.go index 95c4298bc8..d3b7f5e09e 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -52,6 +52,7 @@ type API interface { // --------------------------------------------------------------------------------------------- // Bit operations + // TODO @gbotrel move bit operations in std/math/bits // ToBinary unpacks a Variable in binary, // n is the number of bits to select (starting from lsb) diff --git a/frontend/ccs.go b/frontend/ccs.go index d3659cc5c5..c7d91ca62a 100644 --- a/frontend/ccs.go +++ b/frontend/ccs.go @@ -20,8 +20,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) // CompiledConstraintSystem interface that a compiled (=typed, and correctly routed) diff --git a/frontend/compiled/cs.go b/frontend/compiled/cs.go new file mode 100644 index 0000000000..5a8deaf6ee --- /dev/null +++ b/frontend/compiled/cs.go @@ -0,0 +1,143 @@ +package compiled + +import ( + "fmt" + "io" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/utils" +) + +// ConstraintSystem contains common element between R1CS and ConstraintSystem +type ConstraintSystem struct { + + // schema of the circuit + Schema *schema.Schema + + // number of wires + NbInternalVariables int + NbPublicVariables int + NbSecretVariables int + + // input wires names + Public, Secret []string + + // logs (added with cs.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a cs.API that + // results in an unsolved constraint + DebugInfo []LogEntry + + // maps constraint id to debugInfo id + // several constraints may point to the same debug info + MDebug map[int]int + + Counters []Counter // TODO @gbotrel no point in serializing these + + // maps wire id to hint + // a wire may point to at most one hint + MHints map[int]*Hint + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependncies for constraints in a level l are solved + // in previous levels + Levels [][]int + + CurveID ecc.ID +} + +// GetNbVariables return number of internal, secret and public variables +func (cs *ConstraintSystem) GetNbVariables() (internal, secret, public int) { + return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables +} + +// FrSize panics +func (cs *ConstraintSystem) FrSize() int { panic("not implemented") } + +// GetNbCoefficients panics +func (cs *ConstraintSystem) GetNbCoefficients() int { panic("not implemented") } + +// // CurveID returns ecc.UNKNOWN +// func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } + +// WriteTo panics +func (cs *ConstraintSystem) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } + +// ReadFrom panics +func (cs *ConstraintSystem) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } + +// ToHTML panics +func (cs *ConstraintSystem) ToHTML(w io.Writer) error { panic("not implemtened") } + +// GetCounters return the collected constraint counters, if any +func (cs *ConstraintSystem) GetCounters() []Counter { return cs.Counters } + +func (cs *ConstraintSystem) GetSchema() *schema.Schema { return cs.Schema } + +func (cs *ConstraintSystem) GetConstraints() [][]string { panic("not implemented") } + +// Counter contains measurements of useful statistics between two Tag +type Counter struct { + From, To string + NbVariables int + NbConstraints int + CurveID ecc.ID + BackendID backend.ID +} + +func (c Counter) String() string { + return fmt.Sprintf("%s[%s] %s - %s: %d variables, %d constraints", c.BackendID, c.CurveID, c.From, c.To, c.NbVariables, c.NbConstraints) +} + +func (cs *ConstraintSystem) Curve() ecc.ID { + return cs.CurveID +} + +func (cs *ConstraintSystem) AddDebugInfo(errName string, i ...interface{}) int { + + var l LogEntry + + const minLogSize = 500 + var sbb strings.Builder + sbb.Grow(minLogSize) + sbb.WriteString("[") + sbb.WriteString(errName) + sbb.WriteString("] ") + + for _, _i := range i { + switch v := _i.(type) { + case Variable: + if len(v.LinExp) > 1 { + sbb.WriteString("(") + } + l.WriteVariable(v, &sbb) + if len(v.LinExp) > 1 { + sbb.WriteString(")") + } + case string: + sbb.WriteString(v) + case Term: + l.WriteTerm(v, &sbb) + default: + _v := utils.FromInterface(v) + sbb.WriteString(_v.String()) + } + } + sbb.WriteByte('\n') + debug.WriteStack(&sbb) + l.Format = sbb.String() + + cs.DebugInfo = append(cs.DebugInfo, l) + + return len(cs.DebugInfo) - 1 +} + +// bitLen returns the number of bits needed to represent a fr.Element +func (cs *ConstraintSystem) BitLen() int { + return cs.CurveID.Info().Fr.Bits +} diff --git a/internal/backend/compiled/cs.go b/frontend/compiled/hint.go similarity index 60% rename from internal/backend/compiled/cs.go rename to frontend/compiled/hint.go index b91dbba9f4..8aeee12208 100644 --- a/internal/backend/compiled/cs.go +++ b/frontend/compiled/hint.go @@ -2,49 +2,13 @@ package compiled import ( "fmt" - "io" "math/big" "reflect" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/frontend/schema" "github.com/fxamacker/cbor/v2" ) -// CS contains common element between R1CS and CS -type CS struct { - - // number of wires - NbInternalVariables int - NbPublicVariables int - NbSecretVariables int - - // logs (added with cs.Println, resolved when solver sets a value to a wire) - Logs []LogEntry - // debug info contains stack trace (including line number) of a call to a cs.API that - // results in an unsolved constraint - DebugInfo []LogEntry - - // maps constraint id to debugInfo id - // several constraints may point to the same debug info - MDebug map[int]int - - Counters []Counter // TODO @gbotrel no point in serializing these - // maps wire id to hint - - // a wire may point to at most one hint - MHints map[int]*Hint - - Schema *schema.Schema - - // each level contains independent constraints and can be parallelized - // it is guaranteed that all dependncies for constraints in a level l are solved - // in previous levels - Levels [][]int -} - // Hint represents a solver hint // it enables the solver to compute a Wire with a function provided at solving time // using pre-defined inputs @@ -173,46 +137,3 @@ func (h *Hint) UnmarshalCBOR(b []byte) error { h.Wires = v.Wires return nil } - -// GetNbVariables return number of internal, secret and public variables -func (cs *CS) GetNbVariables() (internal, secret, public int) { - return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables -} - -// FrSize panics -func (cs *CS) FrSize() int { panic("not implemented") } - -// GetNbCoefficients panics -func (cs *CS) GetNbCoefficients() int { panic("not implemented") } - -// CurveID returns ecc.UNKNOWN -func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } - -// WriteTo panics -func (cs *CS) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } - -// ReadFrom panics -func (cs *CS) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } - -// ToHTML panics -func (cs *CS) ToHTML(w io.Writer) error { panic("not implemtened") } - -// GetCounters return the collected constraint counters, if any -func (cs *CS) GetCounters() []Counter { return cs.Counters } - -func (cs *CS) GetSchema() *schema.Schema { return cs.Schema } - -func (cs *CS) GetConstraints() [][]string { panic("not implemented") } - -// Counter contains measurements of useful statistics between two Tag -type Counter struct { - From, To string - NbVariables int - NbConstraints int - CurveID ecc.ID - BackendID backend.ID -} - -func (c Counter) String() string { - return fmt.Sprintf("%s[%s] %s - %s: %d variables, %d constraints", c.BackendID, c.CurveID, c.From, c.To, c.NbVariables, c.NbConstraints) -} diff --git a/internal/backend/compiled/log.go b/frontend/compiled/log.go similarity index 100% rename from internal/backend/compiled/log.go rename to frontend/compiled/log.go diff --git a/internal/backend/compiled/r1c.go b/frontend/compiled/r1cs.go similarity index 81% rename from internal/backend/compiled/r1c.go rename to frontend/compiled/r1cs.go index a2052de1d0..f2645385b7 100644 --- a/internal/backend/compiled/r1c.go +++ b/frontend/compiled/r1cs.go @@ -19,6 +19,17 @@ import ( "strings" ) +// R1CS decsribes a set of R1C constraint +type R1CS struct { + ConstraintSystem + Constraints []R1C +} + +// GetNbConstraints returns the number of constraints +func (r1cs *R1CS) GetNbConstraints() int { + return len(r1cs.Constraints) +} + // R1C used to compute the wires type R1C struct { L, R, O Variable diff --git a/internal/backend/compiled/r1c_sparse.go b/frontend/compiled/r1cs_sparse.go similarity index 84% rename from internal/backend/compiled/r1c_sparse.go rename to frontend/compiled/r1cs_sparse.go index c7adee2a06..20f554d6dd 100644 --- a/internal/backend/compiled/r1c_sparse.go +++ b/frontend/compiled/r1cs_sparse.go @@ -19,6 +19,17 @@ import ( "strings" ) +// R1CS decsribes a set of SparseR1C constraint +type SparseR1CS struct { + ConstraintSystem + Constraints []SparseR1C +} + +// GetNbConstraints returns the number of constraints +func (cs *SparseR1CS) GetNbConstraints() int { + return len(cs.Constraints) +} + // SparseR1C used to compute the wires // L+R+M[0]M[1]+O+k=0 // if a Term is zero, it means the field doesn't exist (ex M=[0,0] means there is no multiplicative term) diff --git a/internal/backend/compiled/term.go b/frontend/compiled/term.go similarity index 81% rename from internal/backend/compiled/term.go rename to frontend/compiled/term.go index 57966339b9..3d7b2d7380 100644 --- a/internal/backend/compiled/term.go +++ b/frontend/compiled/term.go @@ -22,13 +22,14 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -// Term lightweight version of a term, no pointers -// first 4 bits are reserved -// next 30 bits represented the coefficient idx (in r1cs.Coefficients) by which the wire is multiplied -// next 30 bits represent the constraint used to compute the wire -// if we support more than 1 billion constraints, this breaks (not so soon.) +// Term lightweight version of a term, no pointers. A term packs wireID, coeffID, visibility and +// some metadata associated with the term, in a uint64. +// note: if we support more than 1 billion constraints, this breaks (not so soon.) type Term uint64 +// A linear expression is a linear combination of Term +type LinearExpression []Term + // ids of the coefficients with simple values in any cs.coeffs slice. const ( CoeffIdZero = 0 @@ -178,3 +179,41 @@ func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { } sbb.WriteString(strconv.Itoa(t.WireID())) } + +// Len return the lenght of the Variable (implements Sort interface) +func (v LinearExpression) Len() int { + return len(v) +} + +// Equals returns true if both SORTED expressions are the same +// +// pre conditions: l and o are sorted +func (v LinearExpression) Equal(o LinearExpression) bool { + if len(v) != len(o) { + return false + } + if (v == nil) != (o == nil) { + return false + } + for i := 0; i < len(v); i++ { + if v[i] != o[i] { + return false + } + } + return true +} + +// Swap swaps terms in the Variable (implements Sort interface) +func (v LinearExpression) Swap(i, j int) { + v[i], v[j] = v[j], v[i] +} + +// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) +func (v LinearExpression) Less(i, j int) bool { + _, iID, iVis := v[i].Unpack() + _, jID, jVis := v[j].Unpack() + if iVis == jVis { + return iID < jID + } + return iVis > jVis +} diff --git a/internal/backend/compiled/variable.go b/frontend/compiled/variable.go similarity index 67% rename from internal/backend/compiled/variable.go rename to frontend/compiled/variable.go index 3c6d499055..41085ba9de 100644 --- a/internal/backend/compiled/variable.go +++ b/frontend/compiled/variable.go @@ -25,8 +25,6 @@ import ( // errNoValue triggered when trying to access a variable that was not allocated var errNoValue = errors.New("can't determine API input value") -type LinearExpression []Term - // Variable represent a linear expression of wires type Variable struct { LinExp LinearExpression @@ -42,44 +40,6 @@ func (v Variable) Clone() Variable { return res } -// Len return the lenght of the Variable (implements Sort interface) -func (v LinearExpression) Len() int { - return len(v) -} - -// Equals returns true if both SORTED expressions are the same -// -// pre conditions: l and o are sorted -func (v LinearExpression) Equal(o LinearExpression) bool { - if len(v) != len(o) { - return false - } - if (v == nil) != (o == nil) { - return false - } - for i := 0; i < len(v); i++ { - if v[i] != o[i] { - return false - } - } - return true -} - -// Swap swaps terms in the Variable (implements Sort interface) -func (v LinearExpression) Swap(i, j int) { - v[i], v[j] = v[j], v[i] -} - -// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) -func (v LinearExpression) Less(i, j int) bool { - _, iID, iVis := v[i].Unpack() - _, jID, jVis := v[j].Unpack() - if iVis == jVis { - return iID < jID - } - return iVis > jVis -} - func (v Variable) string(sbb *strings.Builder, coeffs []big.Int) { for i := 0; i < len(v.LinExp); i++ { v.LinExp[i].string(sbb, coeffs) diff --git a/frontend/cs/builder.go b/frontend/cs/builder.go new file mode 100644 index 0000000000..64a93832e6 --- /dev/null +++ b/frontend/cs/builder.go @@ -0,0 +1,66 @@ +package cs + +import ( + "math/big" +) + +// Builder helps build a constraint system but need not be serialized after compilation +type Builder struct { + // Coefficients in the constraints + Coeffs []big.Int // list of unique coefficients. + CoeffsIDsLarge map[string]int // map to check existence of a coefficient (key = coeff.Bytes()) + CoeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) + + // map for recording boolean constrained variables (to not constrain them twice) + MTBooleans map[int]struct{} +} + +func NewBuilder() Builder { + return Builder{ + Coeffs: make([]big.Int, 4), + CoeffsIDsLarge: make(map[string]int), + CoeffsIDsInt64: make(map[int64]int, 4), + MTBooleans: make(map[int]struct{}), + } +} + +// CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to +// the list of Coeffs and returns the corresponding entry +func (b *Builder) CoeffID(v *big.Int) int { + + // if the coeff is a int64 we have a fast path. + if v.IsInt64() { + return b.coeffID64(v.Int64()) + } + + // GobEncode is 3x faster than b.Text(16). Slightly slower than Bytes, but Bytes return the same + // thing for -x and x . + bKey, _ := v.GobEncode() + key := string(bKey) + + // if the coeff is already stored, fetch its ID from the cs.CoeffsIDs map + if idx, ok := b.CoeffsIDsLarge[key]; ok { + return idx + } + + // else add it in the cs.Coeffs map and update the cs.CoeffsIDs map + var bCopy big.Int + bCopy.Set(v) + resID := len(b.Coeffs) + b.Coeffs = append(b.Coeffs, bCopy) + b.CoeffsIDsLarge[key] = resID + return resID +} + +func (b *Builder) coeffID64(v int64) int { + if resID, ok := b.CoeffsIDsInt64[v]; ok { + return resID + } else { + var bCopy big.Int + bCopy.SetInt64(v) + resID := len(b.Coeffs) + b.Coeffs = append(b.Coeffs, bCopy) + b.CoeffsIDsInt64[v] = resID + return resID + } +} diff --git a/frontend/cs/cs.go b/frontend/cs/cs.go deleted file mode 100644 index 3bad64cd4d..0000000000 --- a/frontend/cs/cs.go +++ /dev/null @@ -1,135 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cs - -import ( - "math/big" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" -) - -// ConstraintSystem contains the parts common to plonk and Groth16 -type ConstraintSystem struct { - compiled.CS - - // input wires - Public, Secret []string - - CurveID ecc.ID - // BackendID backend.ID - - // Coefficients in the constraints - Coeffs []big.Int // list of unique coefficients. - CoeffsIDsLarge map[string]int // map to check existence of a coefficient (key = coeff.Bytes()) - CoeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) - - // map for recording boolean constrained variables (to not constrain them twice) - MTBooleans map[int]struct{} -} - -func (cs *ConstraintSystem) Curve() ecc.ID { - return cs.CurveID -} - -func (cs *ConstraintSystem) CoeffID64(v int64) int { - if resID, ok := cs.CoeffsIDsInt64[v]; ok { - return resID - } else { - var bCopy big.Int - bCopy.SetInt64(v) - resID := len(cs.Coeffs) - cs.Coeffs = append(cs.Coeffs, bCopy) - cs.CoeffsIDsInt64[v] = resID - return resID - } -} - -// CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to -// the list of Coeffs and returns the corresponding entry -func (cs *ConstraintSystem) CoeffID(b *big.Int) int { - - // if the coeff is a int64 we have a fast path. - if b.IsInt64() { - return cs.CoeffID64(b.Int64()) - } - - // GobEncode is 3x faster than b.Text(16). Slightly slower than Bytes, but Bytes return the same - // thing for -x and x . - bKey, _ := b.GobEncode() - key := string(bKey) - - // if the coeff is already stored, fetch its ID from the cs.CoeffsIDs map - if idx, ok := cs.CoeffsIDsLarge[key]; ok { - return idx - } - - // else add it in the cs.Coeffs map and update the cs.CoeffsIDs map - var bCopy big.Int - bCopy.Set(b) - resID := len(cs.Coeffs) - cs.Coeffs = append(cs.Coeffs, bCopy) - cs.CoeffsIDsLarge[key] = resID - return resID -} - -func (cs *ConstraintSystem) AddDebugInfo(errName string, i ...interface{}) int { - - var l compiled.LogEntry - - const minLogSize = 500 - var sbb strings.Builder - sbb.Grow(minLogSize) - sbb.WriteString("[") - sbb.WriteString(errName) - sbb.WriteString("] ") - - for _, _i := range i { - switch v := _i.(type) { - case compiled.Variable: - if len(v.LinExp) > 1 { - sbb.WriteString("(") - } - l.WriteVariable(v, &sbb) - if len(v.LinExp) > 1 { - sbb.WriteString(")") - } - case string: - sbb.WriteString(v) - case compiled.Term: - l.WriteTerm(v, &sbb) - default: - _v := utils.FromInterface(v) - sbb.WriteString(_v.String()) - } - } - sbb.WriteByte('\n') - debug.WriteStack(&sbb) - l.Format = sbb.String() - - cs.DebugInfo = append(cs.DebugInfo, l) - - return len(cs.DebugInfo) - 1 -} - -// bitLen returns the number of bits needed to represent a fr.Element -func (cs *ConstraintSystem) BitLen() int { - return cs.CurveID.Info().Fr.Bits -} diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go index 385d483efe..d1c75b5764 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/plonk/api.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" ) @@ -46,7 +46,7 @@ func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) return system.splitSum(vars[0], vars[1:]) } cl, _, _ := vars[0].Unpack() - kID := system.CoeffID(&k) + kID := system.builder.CoeffID(&k) o := system.newInternalVariable() system.addPlonkConstraint(vars[0], system.zero(), o, cl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdMinusOne, kID) return system.splitSum(o, vars[1:]) @@ -80,9 +80,9 @@ func (system *sparseR1CS) Neg(i1 frontend.Variable) frontend.Variable { v := i1.(compiled.Term) c, _, _ := v.Unpack() var coef big.Int - coef.Set(&system.Coeffs[c]) + coef.Set(&system.builder.Coeffs[c]) coef.Neg(&coef) - c = system.CoeffID(&coef) + c = system.builder.CoeffID(&coef) v.SetCoeffID(c) return v } @@ -104,9 +104,9 @@ func (system *sparseR1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) func (system *sparseR1CS) mulConstant(t compiled.Term, m *big.Int) compiled.Term { var coef big.Int cid, _, _ := t.Unpack() - coef.Set(&system.Coeffs[cid]) + coef.Set(&system.builder.Coeffs[cid]) coef.Mul(m, &coef).Mod(&coef, system.CurveID.Info().Fr.Modulus()) - cid = system.CoeffID(&coef) + cid = system.builder.CoeffID(&coef) t.SetCoeffID(cid) return t } @@ -274,7 +274,7 @@ func (system *sparseR1CS) Xor(a, b frontend.Variable) frontend.Variable { _b := utils.FromInterface(b) one := big.NewInt(1) _b.Lsh(&_b, 1).Sub(&_b, one) - idl := system.CoeffID(&_b) + idl := system.builder.CoeffID(&_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } @@ -314,7 +314,7 @@ func (system *sparseR1CS) Or(a, b frontend.Variable) frontend.Variable { one := big.NewInt(1) _b.Sub(&_b, one) - idl := system.CoeffID(&_b) + idl := system.builder.CoeffID(&_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } diff --git a/frontend/cs/plonk/assertions.go b/frontend/cs/plonk/assertions.go index efe58e8627..7657d64d7d 100644 --- a/frontend/cs/plonk/assertions.go +++ b/frontend/cs/plonk/assertions.go @@ -21,7 +21,7 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/internal/utils" ) @@ -45,7 +45,7 @@ func (system *sparseR1CS) AssertIsEqual(i1, i2 frontend.Variable) { k := utils.FromInterface(i2) debug := system.AddDebugInfo("assertIsEqual", l, "+", i2, " == 0") k.Neg(&k) - _k := system.CoeffID(&k) + _k := system.builder.CoeffID(&k) system.addPlonkConstraint(l, system.zero(), system.zero(), lc, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, _k, debug) return } @@ -77,12 +77,12 @@ func (system *sparseR1CS) AssertIsBoolean(i1 frontend.Variable) { return } system.markBoolean(t) - system.MTBooleans[int(t)] = struct{}{} + system.builder.MTBooleans[int(t)] = struct{}{} debug := system.AddDebugInfo("assertIsBoolean", t, " == (0|1)") cID, _, _ := t.Unpack() var mCoef big.Int - mCoef.Neg(&system.Coeffs[cID]) - mcID := system.CoeffID(&mCoef) + mCoef.Neg(&system.builder.Coeffs[cID]) + mcID := system.builder.CoeffID(&mCoef) system.addPlonkConstraint(t, t, system.zero(), cID, compiled.CoeffIdZero, mcID, cID, compiled.CoeffIdZero, compiled.CoeffIdZero, debug) } diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/plonk/conversion.go index 1f2e8d564e..1710a3f4ed 100644 --- a/frontend/cs/plonk/conversion.go +++ b/frontend/cs/plonk/conversion.go @@ -19,6 +19,7 @@ package plonk import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" @@ -26,14 +27,13 @@ import ( bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" - "github.com/consensys/gnark/internal/backend/compiled" ) func (cs *sparseR1CS) Compile() (frontend.CompiledConstraintSystem, error) { res := compiled.SparseR1CS{ - CS: cs.CS, - Constraints: cs.Constraints, + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, } res.NbPublicVariables = len(cs.Public) res.NbSecretVariables = len(cs.Secret) @@ -121,17 +121,17 @@ HINTLOOP: switch cs.CurveID { case ecc.BLS12_377: - return bls12377r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bls12377r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil case ecc.BLS12_381: - return bls12381r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bls12381r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil case ecc.BN254: - return bn254r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bn254r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil case ecc.BW6_761: - return bw6761r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bw6761r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil case ecc.BLS24_315: - return bls24315r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bls24315r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil case ecc.BW6_633: - return bw6633r1cs.NewSparseR1CS(res, cs.Coeffs), nil + return bw6633r1cs.NewSparseR1CS(res, cs.builder.Coeffs), nil default: panic("unknown curveID") } diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/plonk/sparse_r1cs.go index 2e4beac548..01afe21a91 100644 --- a/frontend/cs/plonk/sparse_r1cs.go +++ b/frontend/cs/plonk/sparse_r1cs.go @@ -26,9 +26,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) func NewBuilder(curve ecc.ID) (frontend.Builder, error) { @@ -36,9 +36,10 @@ func NewBuilder(curve ecc.ID) (frontend.Builder, error) { } type sparseR1CS struct { - cs.ConstraintSystem - + compiled.ConstraintSystem Constraints []compiled.SparseR1C + + builder cs.Builder } // initialCapacity has quite some impact on frontend performance, especially on large circuits size @@ -49,30 +50,24 @@ func newSparseR1CS(curveID ecc.ID, initialCapacity ...int) *sparseR1CS { capacity = initialCapacity[0] } system := sparseR1CS{ - ConstraintSystem: cs.ConstraintSystem{ - - CS: compiled.CS{ - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, + ConstraintSystem: compiled.ConstraintSystem{ - Coeffs: make([]big.Int, 4), - CoeffsIDsLarge: make(map[string]int), - CoeffsIDsInt64: make(map[int64]int, 4), - MTBooleans: make(map[int]struct{}), + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), }, Constraints: make([]compiled.SparseR1C, 0, capacity), + builder: cs.NewBuilder(), } - system.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + system.builder.Coeffs[compiled.CoeffIdZero].SetInt64(0) + system.builder.Coeffs[compiled.CoeffIdOne].SetInt64(1) + system.builder.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + system.builder.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - system.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + system.builder.CoeffsIDsInt64[0] = compiled.CoeffIdZero + system.builder.CoeffsIDsInt64[1] = compiled.CoeffIdOne + system.builder.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + system.builder.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne // system.public.variables = make([]Variable, 0) // system.secret.variables = make([]Variable, 0) @@ -144,9 +139,9 @@ func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExp ccID, cvID, cVis := l[i].Unpack() if pVis == cVis && pvID == cvID { // we have redundancy - c.Add(&system.Coeffs[pcID], &system.Coeffs[ccID]) + c.Add(&system.builder.Coeffs[pcID], &system.builder.Coeffs[ccID]) c.Mod(c, mod) - l[i-1].SetCoeffID(system.CoeffID(c)) + l[i-1].SetCoeffID(system.builder.CoeffID(c)) l = append(l[:i], l[i+1:]...) i-- } @@ -162,13 +157,13 @@ func (system *sparseR1CS) zero() compiled.Term { // returns true if a variable is already boolean func (system *sparseR1CS) isBoolean(t compiled.Term) bool { - _, ok := system.MTBooleans[int(t)] + _, ok := system.builder.MTBooleans[int(t)] return ok } // markBoolean records t in the map to not boolean constrain it twice func (system *sparseR1CS) markBoolean(t compiled.Term) { - system.MTBooleans[int(t)] = struct{}{} + system.builder.MTBooleans[int(t)] = struct{}{} } // checkVariables perform post compilation checks on the Variables diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index e6f06ce614..4045e694ce 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" ) @@ -156,10 +156,10 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl case compiled.CoeffIdTwo: newCoeff.Add(lambda, lambda) default: - coeff := system.Coeffs[cID] + coeff := system.builder.Coeffs[cID] newCoeff.Mul(&coeff, lambda) } - res.LinExp[i] = compiled.Pack(vID, system.CoeffID(&newCoeff), visibility) + res.LinExp[i] = compiled.Pack(vID, system.builder.CoeffID(&newCoeff), visibility) } t := false res.IsBoolean = &t @@ -796,8 +796,8 @@ func (system *r1CS) negateLinExp(l []compiled.Term) []compiled.Term { var lambda big.Int for i, t := range l { cID, vID, visibility := t.Unpack() - lambda.Neg(&system.Coeffs[cID]) - cID = system.CoeffID(&lambda) + lambda.Neg(&system.builder.Coeffs[cID]) + cID = system.builder.CoeffID(&lambda) res[i] = compiled.Pack(vID, cID, visibility) } return res diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go index aba87440c0..ed897456a6 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/assertions.go @@ -21,7 +21,7 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/internal/utils" ) diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index c71d520e9d..3d9be7da55 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -19,6 +19,7 @@ package r1cs import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" @@ -26,7 +27,6 @@ import ( bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" - "github.com/consensys/gnark/internal/backend/compiled" ) // Compile constructs a rank-1 constraint sytem @@ -36,8 +36,8 @@ func (cs *r1CS) Compile() (frontend.CompiledConstraintSystem, error) { // setting up the result res := compiled.R1CS{ - CS: cs.CS, - Constraints: cs.Constraints, + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, } res.NbPublicVariables = len(cs.Public) res.NbSecretVariables = len(cs.Secret) @@ -133,17 +133,17 @@ HINTLOOP: switch cs.CurveID { case ecc.BLS12_377: - return bls12377r1cs.NewR1CS(res, cs.Coeffs), nil + return bls12377r1cs.NewR1CS(res, cs.builder.Coeffs), nil case ecc.BLS12_381: - return bls12381r1cs.NewR1CS(res, cs.Coeffs), nil + return bls12381r1cs.NewR1CS(res, cs.builder.Coeffs), nil case ecc.BN254: - return bn254r1cs.NewR1CS(res, cs.Coeffs), nil + return bn254r1cs.NewR1CS(res, cs.builder.Coeffs), nil case ecc.BW6_761: - return bw6761r1cs.NewR1CS(res, cs.Coeffs), nil + return bw6761r1cs.NewR1CS(res, cs.builder.Coeffs), nil case ecc.BW6_633: - return bw6633r1cs.NewR1CS(res, cs.Coeffs), nil + return bw6633r1cs.NewR1CS(res, cs.builder.Coeffs), nil case ecc.BLS24_315: - return bls24315r1cs.NewR1CS(res, cs.Coeffs), nil + return bls24315r1cs.NewR1CS(res, cs.builder.Coeffs), nil default: panic("not implemtented") } diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index eb22b890b8..ab82970581 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -26,9 +26,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/cs" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) func NewBuilder(curve ecc.ID) (frontend.Builder, error) { @@ -36,9 +36,10 @@ func NewBuilder(curve ecc.ID) (frontend.Builder, error) { } type r1CS struct { - cs.ConstraintSystem - + compiled.ConstraintSystem Constraints []compiled.R1C + + builder cs.Builder } // initialCapacity has quite some impact on frontend performance, especially on large circuits size @@ -49,35 +50,25 @@ func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { capacity = initialCapacity[0] } system := r1CS{ - ConstraintSystem: cs.ConstraintSystem{ + ConstraintSystem: compiled.ConstraintSystem{ - CS: compiled.CS{ - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, - - Coeffs: make([]big.Int, 4), - CoeffsIDsLarge: make(map[string]int), - CoeffsIDsInt64: make(map[int64]int, 4), + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), }, Constraints: make([]compiled.R1C, 0, capacity), - - // Counters: make([]Counter, 0), + builder: cs.NewBuilder(), } - system.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + system.builder.Coeffs[compiled.CoeffIdZero].SetInt64(0) + system.builder.Coeffs[compiled.CoeffIdOne].SetInt64(1) + system.builder.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + system.builder.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - system.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + system.builder.CoeffsIDsInt64[0] = compiled.CoeffIdZero + system.builder.CoeffsIDsInt64[1] = compiled.CoeffIdOne + system.builder.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + system.builder.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - // system.public.variables = make([]Variable, 0) - // system.secret.variables = make([]Variable, 0) - // system.internal = make([]Variable, 0, capacity) system.Public = make([]string, 1) system.Secret = make([]string, 0) @@ -85,7 +76,6 @@ func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { system.Public[0] = "one" system.CurveID = curveID - // system.BackendID = backendID return &system } @@ -133,7 +123,7 @@ func (system *r1CS) constantValue(v compiled.Variable) *big.Int { if !v.IsConstant() { panic("can't get big.Int value on a non-constant variable") } - return new(big.Int).Set(&system.Coeffs[v.LinExp[0].CoeffID()]) + return new(big.Int).Set(&system.builder.Coeffs[v.LinExp[0].CoeffID()]) } func (system *r1CS) one() compiled.Variable { @@ -161,9 +151,9 @@ func (system *r1CS) reduce(l compiled.Variable) compiled.Variable { ccID, cvID, cVis := l.LinExp[i].Unpack() if pVis == cVis && pvID == cvID { // we have redundancy - c.Add(&system.Coeffs[pcID], &system.Coeffs[ccID]) + c.Add(&system.builder.Coeffs[pcID], &system.builder.Coeffs[ccID]) c.Mod(c, mod) - l.LinExp[i-1].SetCoeffID(system.CoeffID(c)) + l.LinExp[i-1].SetCoeffID(system.builder.CoeffID(c)) l.LinExp = append(l.LinExp[:i], l.LinExp[i+1:]...) i-- } @@ -202,7 +192,7 @@ func (system *r1CS) addConstraint(r1c compiled.R1C, debugID ...int) { // func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { func (system *r1CS) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { _, vID, vVis := v.Unpack() - return compiled.Pack(vID, system.CoeffID(coeff), vVis) + return compiled.Pack(vID, system.builder.CoeffID(coeff), vVis) } // markBoolean marks the Variable as boolean and return true diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index 573bb2a0d9..d16bb98c3a 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -21,8 +21,8 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) func TestQuickSort(t *testing.T) { diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 1313d715e4..a73732f8e8 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 8c5d2c087b..c1d6a0e887 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index b8b7dd02bb..80c77b6a9f 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 21fe78713c..16e7b6926d 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 0a32652ad9..56f2b50f40 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 106e6eb0eb..a3ce43cd74 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index cb0964161d..3304585789 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index d481f8ad0d..e30ddcff5c 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 964acf1f50..8fb821ccb3 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 9ecd27fb26..8451a8e5cf 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 179b3490ec..6d10576504 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index 596e7786fc..3a7dea1d87 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index c44696bae1..27874fd9ed 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 8eab1f901f..0fa0d1523a 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bn254/fr" diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 2c526afcd1..b58d02cbe6 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bn254/fr" diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 334461b25b..a5280b909b 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 7101b1a5b5..4f6afe073b 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index 7cc2ccfad0..196842fa42 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index f4e08474ef..529820639f 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go index b26489abbc..9a6b230ed7 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/internal/backend/bw6-633/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 097d53581c..e774f7408c 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 8cc16bcf4c..68c01f59f3 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index 162ab02f8f..81a0600110 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 0e100d3319..06127edd0b 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/backend/compiled/html.go b/internal/backend/compiled/html.go deleted file mode 100644 index b1335c396b..0000000000 --- a/internal/backend/compiled/html.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -const R1CSTemplate = ` - - - - - - - R1CS - - - - - - -
-

R1CS

- {{ $nbHints := len .MHints }} - {{ $nbConstraints := len .Constraints}} - {{.NbInternalVariables}} internal (includes {{$nbHints}} hints)
- {{.NbPublicVariables}} public
- {{.NbSecretVariables}} secret
- {{$nbConstraints}} constraints
-

L * R == O

-

-

-
- - - - - - - - - - - {{- range $i, $c := .Constraints}} - - - - - - - {{- end }} - -
#LRO
{{$i}} {{ toHTML $c.L $.Coefficients $.MHints}} {{ toHTML $c.R $.Coefficients $.MHints}} {{ toHTML $c.O $.Coefficients $.MHints}}
- - -` - -const SparseR1CSTemplate = ` - - - - - - - SparseR1CS - - - - - - -
-

SparseR1CS

- {{ $nbHints := len .MHints }} - {{ $nbConstraints := len .Constraints}} - - {{.NbInternalVariables}} internal (includes {{$nbHints}} hints)
- {{.NbPublicVariables}} public
- {{.NbSecretVariables}} secret
- {{$nbConstraints}} constraints
-

L + R + M0*M1 + O + k == 0

-

all variable id are offseted by 1 to match R1CS

-
- - - - - - - - - - - - - - - {{- range $i, $c := .Constraints}} - - - - - - - - - - {{- end }} - -
#LRM0M1Ok
{{$i}} {{ toHTML $c.L $.Coefficients $.MHints}} {{ toHTML $c.R $.Coefficients $.MHints}} {{ toHTML (index $c.M 0) $.Coefficients $.MHints}} {{ toHTML (index $c.M 1) $.Coefficients $.MHints}} {{ toHTML $c.O $.Coefficients $.MHints}} {{ toHTMLCoeff $c.K $.Coefficients }}
- - -` diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go deleted file mode 100644 index 8831f6da01..0000000000 --- a/internal/backend/compiled/r1cs.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -// R1CS decsribes a set of R1C constraint -type R1CS struct { - CS - Constraints []R1C -} - -// GetNbConstraints returns the number of constraints -func (r1cs *R1CS) GetNbConstraints() int { - return len(r1cs.Constraints) -} diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go deleted file mode 100644 index bcd59d5994..0000000000 --- a/internal/backend/compiled/r1cs_sparse.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -// R1CS decsribes a set of SparseR1C constraint -type SparseR1CS struct { - CS - Constraints []SparseR1C -} - -// GetNbConstraints returns the number of constraints -func (cs *SparseR1CS) GetNbConstraints() int { - return len(cs.Constraints) -} diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index ee3b99f846..45dd4fc4f6 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -9,7 +9,7 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index dbc9915d2c..0a220f3a4d 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -11,7 +11,7 @@ import ( "math" "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/backend/witness" diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index a45ae12b63..4929438644 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -6,7 +6,7 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/frontend/schema" diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 14933f3931..5d25b59ff5 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -4,7 +4,7 @@ import ( {{ template "import_backend_cs" . }} {{ template "import_fft" . }} "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) diff --git a/internal/utils/circuit.go b/internal/utils/circuit.go index 2c2c4f8a1d..d4b585bf78 100644 --- a/internal/utils/circuit.go +++ b/internal/utils/circuit.go @@ -1,67 +1 @@ package utils - -import ( - "fmt" - "reflect" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/schema" -) - -// ShallowClone clones given circuit -// this is actually a shallow copy → if the circuits contains maps or slices -// only the reference is copied. -func ShallowClone(circuit frontend.Circuit) frontend.Circuit { - - cValue := reflect.ValueOf(circuit).Elem() - newCircuit := reflect.New(cValue.Type()) - newCircuit.Elem().Set(cValue) - - circuitCopy, ok := newCircuit.Interface().(frontend.Circuit) - if !ok { - panic("couldn't clone the circuit") - } - - if !reflect.DeepEqual(circuitCopy, circuit) { - panic("clone failed") - } - - return circuitCopy -} - -func CopyWitness(to, from frontend.Circuit) { - var wValues []interface{} - - var collectHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { - v := tInput.Interface().(frontend.Variable) - - if visibility == schema.Secret || visibility == schema.Public { - if v == nil { - return fmt.Errorf("when parsing variable %s: missing assignment", name) - } - wValues = append(wValues, v) - } - return nil - } - if _, err := schema.Parse(from, tVariable, collectHandler); err != nil { - panic(err) - } - - i := 0 - var setHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { - if visibility == schema.Secret || visibility == schema.Public { - tInput.Set(reflect.ValueOf((wValues[i]))) - i++ - } - return nil - } - // this can't error. - _, _ = schema.Parse(to, tVariable, setHandler) - -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/test/assert.go b/test/assert.go index fa807e5000..84b4efab2d 100644 --- a/test/assert.go +++ b/test/assert.go @@ -29,8 +29,7 @@ import ( "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/frontend/compiled" "github.com/stretchr/testify/require" ) @@ -340,7 +339,7 @@ func (assert *Assert) Fuzz(circuit frontend.Circuit, fuzzCount int, opts ...Test // first we clone the circuit // then we parse the frontend.Variable and set them to a random value or from our interesting pool // (% of allocations to be tuned) - w := utils.ShallowClone(circuit) + w := shallowClone(circuit) fillers := []filler{randomFiller, binaryFiller, seedFiller} diff --git a/test/engine.go b/test/engine.go index f623ad08ce..f7703758a9 100644 --- a/test/engine.go +++ b/test/engine.go @@ -20,11 +20,13 @@ import ( "fmt" "math/big" "path/filepath" + "reflect" "runtime" "strconv" "strings" "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -71,10 +73,10 @@ func IsSolved(circuit, witness frontend.Circuit, curveID ecc.ID, b backend.ID, o // then, we set all the variables values to the ones from the witness // clone the circuit - c := utils.ShallowClone(circuit) + c := shallowClone(circuit) // set the witness values - utils.CopyWitness(c, witness) + copyWitness(c, witness) defer func() { if r := recover(); r != nil { @@ -408,3 +410,55 @@ func (e *engine) Curve() ecc.ID { func (e *engine) Backend() backend.ID { return e.backendID } + +// shallowClone clones given circuit +// this is actually a shallow copy → if the circuits contains maps or slices +// only the reference is copied. +func shallowClone(circuit frontend.Circuit) frontend.Circuit { + + cValue := reflect.ValueOf(circuit).Elem() + newCircuit := reflect.New(cValue.Type()) + newCircuit.Elem().Set(cValue) + + circuitCopy, ok := newCircuit.Interface().(frontend.Circuit) + if !ok { + panic("couldn't clone the circuit") + } + + if !reflect.DeepEqual(circuitCopy, circuit) { + panic("clone failed") + } + + return circuitCopy +} + +func copyWitness(to, from frontend.Circuit) { + var wValues []interface{} + + var collectHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { + v := tInput.Interface().(frontend.Variable) + + if visibility == schema.Secret || visibility == schema.Public { + if v == nil { + return fmt.Errorf("when parsing variable %s: missing assignment", name) + } + wValues = append(wValues, v) + } + return nil + } + if _, err := schema.Parse(from, tVariable, collectHandler); err != nil { + panic(err) + } + + i := 0 + var setHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { + if visibility == schema.Secret || visibility == schema.Public { + tInput.Set(reflect.ValueOf((wValues[i]))) + i++ + } + return nil + } + // this can't error. + _, _ = schema.Parse(to, tVariable, setHandler) + +} From 5c5dab921a6acfe83087c23cd9f2ba3abefd1732 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 25 Feb 2022 10:57:45 -0600 Subject: [PATCH 07/20] refactor: frontend.Compile now takes a builder instead of backendID as parameter --- README.md | 2 +- backend/groth16/groth16.go | 5 --- backend/plonk/plonk.go | 5 --- circuitstats_test.go | 21 ++++++++--- debug_test.go | 16 +++++---- examples/plonk/main.go | 4 +-- examples/serialization/main.go | 4 +-- frontend/compile.go | 26 ++------------ frontend/compiled/cs.go | 21 ----------- frontend/cs/{plonk => scs}/api.go | 2 +- frontend/cs/{plonk => scs}/assertions.go | 2 +- frontend/cs/{plonk => scs}/conversion.go | 2 +- frontend/cs/{plonk => scs}/sparse_r1cs.go | 2 +- frontend/registry.go | 36 ------------------- internal/backend/bls12-377/cs/r1cs_test.go | 7 ++-- .../backend/bls12-377/groth16/groth16_test.go | 2 +- .../backend/bls12-377/plonk/plonk_test.go | 4 +-- internal/backend/bls12-381/cs/r1cs_test.go | 7 ++-- .../backend/bls12-381/groth16/groth16_test.go | 2 +- .../backend/bls12-381/plonk/plonk_test.go | 4 +-- internal/backend/bls24-315/cs/r1cs_test.go | 7 ++-- .../backend/bls24-315/groth16/groth16_test.go | 2 +- .../backend/bls24-315/plonk/plonk_test.go | 4 +-- internal/backend/bn254/cs/r1cs_test.go | 7 ++-- .../backend/bn254/groth16/groth16_test.go | 2 +- internal/backend/bn254/plonk/plonk_test.go | 4 +-- internal/backend/bw6-633/cs/r1cs_test.go | 7 ++-- .../backend/bw6-633/groth16/groth16_test.go | 2 +- internal/backend/bw6-633/plonk/plonk_test.go | 4 +-- internal/backend/bw6-761/cs/r1cs_test.go | 7 ++-- .../backend/bw6-761/groth16/groth16_test.go | 2 +- internal/backend/bw6-761/plonk/plonk_test.go | 4 +-- .../representations/tests/r1cs.go.tmpl | 7 ++-- .../zkpschemes/groth16/tests/groth16.go.tmpl | 2 +- .../zkpschemes/plonk/tests/plonk.go.tmpl | 6 ++-- std/algebra/fields_bls24315/e24_test.go | 12 +++---- std/algebra/sw_bls12377/g1_test.go | 11 +++--- std/algebra/sw_bls12377/g2_test.go | 8 ++--- std/algebra/sw_bls12377/pairing_test.go | 3 +- std/algebra/sw_bls24315/g1_test.go | 11 +++--- std/algebra/sw_bls24315/g2_test.go | 8 ++--- std/algebra/sw_bls24315/pairing_test.go | 3 +- .../twistededwards/bandersnatch/point_test.go | 14 ++++---- std/fiat-shamir/transcript_test.go | 4 +-- std/groth16_bls12377/verifier_test.go | 5 +-- std/groth16_bls24315/verifier_test.go | 5 +-- test/assert.go | 18 ++++++++-- 47 files changed, 141 insertions(+), 202 deletions(-) rename frontend/cs/{plonk => scs}/api.go (99%) rename frontend/cs/{plonk => scs}/assertions.go (99%) rename frontend/cs/{plonk => scs}/conversion.go (99%) rename frontend/cs/{plonk => scs}/sparse_r1cs.go (99%) delete mode 100644 frontend/registry.go diff --git a/README.md b/README.md index 0a21565df2..1f20c8b16d 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ func (circuit *CubicCircuit) Define(api frontend.API) error { // compiles our circuit into a R1CS var circuit CubicCircuit -ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit) +ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &circuit) // groth16 zkSNARK: Setup pk, vk, err := groth16.Setup(ccs) diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index c7f79554a5..1a286f5a3c 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -27,7 +27,6 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" backend_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/cs" @@ -52,10 +51,6 @@ import ( groth16_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/groth16" ) -func init() { - frontend.RegisterDefaultBuilder(backend.GROTH16, r1cs.NewBuilder) -} - // TODO @gbotrel document hint functions here and in assert type groth16Object interface { diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index 158792bcb3..4676fe77b2 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -26,7 +26,6 @@ import ( "github.com/consensys/gnark-crypto/kzg" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" "github.com/consensys/gnark/backend/witness" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -58,10 +57,6 @@ import ( kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" ) -func init() { - frontend.RegisterDefaultBuilder(backend.PLONK, plonk.NewBuilder) -} - // Proof represents a Plonk proof generated by plonk.Prove // // it's underlying implementation is curve specific (see gnark/internal/backend) diff --git a/circuitstats_test.go b/circuitstats_test.go index 04e3b22964..12acf6da96 100644 --- a/circuitstats_test.go +++ b/circuitstats_test.go @@ -9,6 +9,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "github.com/consensys/gnark/test" ) @@ -26,19 +28,30 @@ func TestCircuitStatistics(t *testing.T) { for _, curve := range ecc.Implemented() { for _, b := range backend.Implemented() { curve := curve - b := b + backendID := b name := k // copy the circuit now in case assert calls t.Parallel() tData := circuits.Circuits[k] assert.Run(func(assert *test.Assert) { - ccs, err := frontend.Compile(curve, b, tData.Circuit) + var newBuilder frontend.NewBuilder + + switch backendID { + case backend.GROTH16: + newBuilder = r1cs.NewBuilder + case backend.PLONK: + newBuilder = scs.NewBuilder + default: + panic("not implemented") + } + + ccs, err := frontend.Compile(curve, newBuilder, tData.Circuit) assert.NoError(err) // ensure we didn't introduce regressions that make circuits less efficient nbConstraints := ccs.GetNbConstraints() internal, secret, public := ccs.GetNbVariables() - checkStats(assert, name, nbConstraints, internal, secret, public, curve, b) - }, name, curve.String(), b.String()) + checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID) + }, name, curve.String(), backendID.String()) } } diff --git a/debug_test.go b/debug_test.go index 51910507d2..b28a2959ff 100644 --- a/debug_test.go +++ b/debug_test.go @@ -10,6 +10,8 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) @@ -43,11 +45,11 @@ func TestPrintln(t *testing.T) { witness.B = 11 var expected bytes.Buffer - expected.WriteString("debug_test.go:25 13 is the addition\n") - expected.WriteString("debug_test.go:27 26 42\n") - expected.WriteString("debug_test.go:29 bits 1\n") - expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:34 m .*\n") + expected.WriteString("debug_test.go:27 13 is the addition\n") + expected.WriteString("debug_test.go:29 26 42\n") + expected.WriteString("debug_test.go:31 bits 1\n") + expected.WriteString("debug_test.go:32 circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:36 m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) @@ -172,7 +174,7 @@ func TestTraceNotBoolean(t *testing.T) { } func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewBuilder, circuit) if err != nil { return "", err } @@ -196,7 +198,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { } func getGroth16Trace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, circuit) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, circuit) if err != nil { return "", err } diff --git a/examples/plonk/main.go b/examples/plonk/main.go index 80a6587822..99563f7a9a 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -19,8 +19,8 @@ import ( "log" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/bn254/cs" "github.com/consensys/gnark/test" @@ -73,7 +73,7 @@ func main() { var circuit Circuit // // building the circuit... - ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, &circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewBuilder, &circuit) if err != nil { fmt.Println("circuit compilation error") } diff --git a/examples/serialization/main.go b/examples/serialization/main.go index dfbcc77c85..538784ca43 100644 --- a/examples/serialization/main.go +++ b/examples/serialization/main.go @@ -6,9 +6,9 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/examples/cubic" ) @@ -17,7 +17,7 @@ func main() { var circuit cubic.Circuit // compile a circuit - _r1cs, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit) + _r1cs, _ := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &circuit) // R1CS implements io.WriterTo and io.ReaderFrom var buf bytes.Buffer diff --git a/frontend/compile.go b/frontend/compile.go index 4565063d0f..2af91d41ba 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -6,7 +6,6 @@ import ( "reflect" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" ) @@ -47,7 +46,7 @@ type NewBuilder func(ecc.ID) (Builder, error) // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. -func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { +func Compile(curveID ecc.ID, newBuilder NewBuilder, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { // setup option opt := compileConfig{} for _, o := range opts { @@ -55,16 +54,7 @@ func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, opts ...CompileO return nil, fmt.Errorf("apply option: %w", err) } } - newBuilder := opt.newBuilder - if newBuilder == nil { - var ok bool - backendsM.RLock() - newBuilder, ok = backends[zkpID] - backendsM.RUnlock() - if !ok { - return nil, fmt.Errorf("no default constraint builder registered nor set as option") - } - } + builder, err := newBuilder(curveID) if err != nil { return nil, fmt.Errorf("new builder: %w", err) @@ -143,7 +133,6 @@ type CompileOption func(opt *compileConfig) error type compileConfig struct { capacity int ignoreUnconstrainedInputs bool - newBuilder NewBuilder } // WithCapacity is a compile option that specifies the estimated capacity needed @@ -169,14 +158,3 @@ func IgnoreUnconstrainedInputs() CompileOption { return nil } } - -// WithBuilder is a compile option which enables the compiler to build the -// constraint system with a user-defined builder. -// -// /!\ This is highly experimental and may change in upcoming releases /!\ -func WithBuilder(builder NewBuilder) CompileOption { - return func(opt *compileConfig) error { - opt.newBuilder = builder - return nil - } -} diff --git a/frontend/compiled/cs.go b/frontend/compiled/cs.go index 5a8deaf6ee..a2070579af 100644 --- a/frontend/compiled/cs.go +++ b/frontend/compiled/cs.go @@ -2,7 +2,6 @@ package compiled import ( "fmt" - "io" "strings" "github.com/consensys/gnark-crypto/ecc" @@ -56,31 +55,11 @@ func (cs *ConstraintSystem) GetNbVariables() (internal, secret, public int) { return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables } -// FrSize panics -func (cs *ConstraintSystem) FrSize() int { panic("not implemented") } - -// GetNbCoefficients panics -func (cs *ConstraintSystem) GetNbCoefficients() int { panic("not implemented") } - -// // CurveID returns ecc.UNKNOWN -// func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } - -// WriteTo panics -func (cs *ConstraintSystem) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } - -// ReadFrom panics -func (cs *ConstraintSystem) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } - -// ToHTML panics -func (cs *ConstraintSystem) ToHTML(w io.Writer) error { panic("not implemtened") } - // GetCounters return the collected constraint counters, if any func (cs *ConstraintSystem) GetCounters() []Counter { return cs.Counters } func (cs *ConstraintSystem) GetSchema() *schema.Schema { return cs.Schema } -func (cs *ConstraintSystem) GetConstraints() [][]string { panic("not implemented") } - // Counter contains measurements of useful statistics between two Tag type Counter struct { From, To string diff --git a/frontend/cs/plonk/api.go b/frontend/cs/scs/api.go similarity index 99% rename from frontend/cs/plonk/api.go rename to frontend/cs/scs/api.go index d1c75b5764..191dd6afd3 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/scs/api.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plonk +package scs import ( "fmt" diff --git a/frontend/cs/plonk/assertions.go b/frontend/cs/scs/assertions.go similarity index 99% rename from frontend/cs/plonk/assertions.go rename to frontend/cs/scs/assertions.go index 7657d64d7d..e80ca9634e 100644 --- a/frontend/cs/plonk/assertions.go +++ b/frontend/cs/scs/assertions.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plonk +package scs import ( "fmt" diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/scs/conversion.go similarity index 99% rename from frontend/cs/plonk/conversion.go rename to frontend/cs/scs/conversion.go index 1710a3f4ed..f5bb6bde28 100644 --- a/frontend/cs/plonk/conversion.go +++ b/frontend/cs/scs/conversion.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plonk +package scs import ( "github.com/consensys/gnark-crypto/ecc" diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/scs/sparse_r1cs.go similarity index 99% rename from frontend/cs/plonk/sparse_r1cs.go rename to frontend/cs/scs/sparse_r1cs.go index 01afe21a91..0afb18aa83 100644 --- a/frontend/cs/plonk/sparse_r1cs.go +++ b/frontend/cs/scs/sparse_r1cs.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plonk +package scs import ( "errors" diff --git a/frontend/registry.go b/frontend/registry.go deleted file mode 100644 index 6c2b3b7642..0000000000 --- a/frontend/registry.go +++ /dev/null @@ -1,36 +0,0 @@ -package frontend - -import ( - "fmt" - "sync" - - "github.com/consensys/gnark/backend" -) - -var ( - backends = make(map[backend.ID]NewBuilder) - backendsM sync.RWMutex -) - -// RegisterDefaultBuilder registers a frontend f for a backend b. This registration -// ensures that a correct frontend system is chosen for a specific backend when -// compiling a circuit. The method does not check that the compiler for that -// frontend is already registered and the compiler is looked up during compile -// time. It is an error to double-assign a frontend to a single backend and the -// mehod panics. -// -// /!\ This is highly experimental and may change in upcoming releases /!\ -func RegisterDefaultBuilder(b backend.ID, builder NewBuilder) { - if b == backend.UNKNOWN { - panic("can not assign builder to unknown backend") - } - - // a frontend may be assigned before a compiler to that frontend is - // registered. we perform frontend compiler lookup during compilation. - backendsM.Lock() - defer backendsM.Unlock() - if _, ok := backends[b]; ok { - panic(fmt.Sprintf("double frontend registration for backend '%s'", b)) - } - backends[b] = builder -} diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index e63b16c7d6..c8b2113d59 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go index a467692558..9d377c713f 100644 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ b/internal/backend/bls12-377/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 8c9affeab7..192ecc0324 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index a96a65d21c..57f05cc18b 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-381/groth16/groth16_test.go b/internal/backend/bls12-381/groth16/groth16_test.go index 721e57e6bf..3bf5084315 100644 --- a/internal/backend/bls12-381/groth16/groth16_test.go +++ b/internal/backend/bls12-381/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/plonk/plonk_test.go b/internal/backend/bls12-381/plonk/plonk_test.go index 5332ba3016..c653d99b0f 100644 --- a/internal/backend/bls12-381/plonk/plonk_test.go +++ b/internal/backend/bls12-381/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 931504b2f7..85ea52bff8 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls24-315/groth16/groth16_test.go b/internal/backend/bls24-315/groth16/groth16_test.go index af5f319940..de108512a5 100644 --- a/internal/backend/bls24-315/groth16/groth16_test.go +++ b/internal/backend/bls24-315/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/plonk/plonk_test.go b/internal/backend/bls24-315/plonk/plonk_test.go index 71378ce85f..44dcbfe868 100644 --- a/internal/backend/bls24-315/plonk/plonk_test.go +++ b/internal/backend/bls24-315/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index c2eb0a7663..2cbd8fb574 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bn254/groth16/groth16_test.go b/internal/backend/bn254/groth16/groth16_test.go index 59667f5ce2..fc4735fe87 100644 --- a/internal/backend/bn254/groth16/groth16_test.go +++ b/internal/backend/bn254/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/plonk/plonk_test.go b/internal/backend/bn254/plonk/plonk_test.go index 0fc5ff68ab..67bc45b90c 100644 --- a/internal/backend/bn254/plonk/plonk_test.go +++ b/internal/backend/bn254/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/cs/r1cs_test.go b/internal/backend/bw6-633/cs/r1cs_test.go index 5900e2b650..2f1ff61692 100644 --- a/internal/backend/bw6-633/cs/r1cs_test.go +++ b/internal/backend/bw6-633/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-633/groth16/groth16_test.go b/internal/backend/bw6-633/groth16/groth16_test.go index deb017f777..188d648bdf 100644 --- a/internal/backend/bw6-633/groth16/groth16_test.go +++ b/internal/backend/bw6-633/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/plonk/plonk_test.go b/internal/backend/bw6-633/plonk/plonk_test.go index fa8b550794..93c9b44350 100644 --- a/internal/backend/bw6-633/plonk/plonk_test.go +++ b/internal/backend/bw6-633/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 94eec2486c..71763a400c 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -41,7 +40,7 @@ func TestSerialization(t *testing.T) { return } - r1cs1, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -50,7 +49,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -139,7 +138,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-761/groth16/groth16_test.go b/internal/backend/bw6-761/groth16/groth16_test.go index 6951cd4a76..eebefa3488 100644 --- a/internal/backend/bw6-761/groth16/groth16_test.go +++ b/internal/backend/bw6-761/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/plonk/plonk_test.go b/internal/backend/bw6-761/plonk/plonk_test.go index d4a25ddaf0..f44eb8982b 100644 --- a/internal/backend/bw6-761/plonk/plonk_test.go +++ b/internal/backend/bw6-761/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index 852a1583fe..2cce6f0613 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -3,7 +3,6 @@ import ( "bytes" "testing" "reflect" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -25,7 +24,7 @@ func TestSerialization(t *testing.T) { } {{end}} - r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -34,7 +33,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewBuilder, tc.Circuit) if err != nil { t.Fatal(err) } @@ -124,7 +123,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl index b83cda5236..c6d4305a4b 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl @@ -38,7 +38,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID,r1cs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl index e378d356f6..140e5854f4 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl @@ -10,10 +10,10 @@ import ( "testing" "reflect" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend/cs/scs" ) @@ -43,7 +43,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) if err != nil { panic(err) } diff --git a/std/algebra/fields_bls24315/e24_test.go b/std/algebra/fields_bls24315/e24_test.go index 526ea0828e..1fc7c69764 100644 --- a/std/algebra/fields_bls24315/e24_test.go +++ b/std/algebra/fields_bls24315/e24_test.go @@ -21,8 +21,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/test" ) @@ -414,7 +414,7 @@ func BenchmarkMulE24(b *testing.B) { var c fp24Mul b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -425,7 +425,7 @@ func BenchmarkSquareE24(b *testing.B) { var c fp24Square b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -436,7 +436,7 @@ func BenchmarkInverseE24(b *testing.B) { var c fp24Inverse b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -447,7 +447,7 @@ func BenchmarkConjugateE24(b *testing.B) { var c fp24Conjugate b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -458,7 +458,7 @@ func BenchmarkMulBy034E24(b *testing.B) { var c fp24MulBy034 b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) diff --git a/std/algebra/sw_bls12377/g1_test.go b/std/algebra/sw_bls12377/g1_test.go index e1f696ead9..8736284061 100644 --- a/std/algebra/sw_bls12377/g1_test.go +++ b/std/algebra/sw_bls12377/g1_test.go @@ -22,8 +22,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -391,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) } }) @@ -399,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) if err != nil { b.Fatal(err) } @@ -420,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) } }) @@ -428,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls12377/g2_test.go b/std/algebra/sw_bls12377/g2_test.go index a32e0e145b..fd87276f8e 100644 --- a/std/algebra/sw_bls12377/g2_test.go +++ b/std/algebra/sw_bls12377/g2_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls12377" "github.com/consensys/gnark/test" @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) } }) diff --git a/std/algebra/sw_bls12377/pairing_test.go b/std/algebra/sw_bls12377/pairing_test.go index 81185f5494..645a5e0307 100644 --- a/std/algebra/sw_bls12377/pairing_test.go +++ b/std/algebra/sw_bls12377/pairing_test.go @@ -25,6 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra/fields_bls12377" "github.com/consensys/gnark/test" ) @@ -199,7 +200,7 @@ func BenchmarkPairing(b *testing.B) { var c pairingBLS377 b.ResetTimer() for i := 0; i < b.N; i++ { - frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) } // ccsBench, _ = compiler.Compile(ecc.BW6_761, backend.GROTH16, &c) // b.Log("groth16", ccsBench.GetNbConstraints()) diff --git a/std/algebra/sw_bls24315/g1_test.go b/std/algebra/sw_bls24315/g1_test.go index 663799eaac..c2a4ef314b 100644 --- a/std/algebra/sw_bls24315/g1_test.go +++ b/std/algebra/sw_bls24315/g1_test.go @@ -22,8 +22,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" @@ -391,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -399,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewBuilder, &c) if err != nil { b.Fatal(err) } @@ -420,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -428,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewBuilder, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls24315/g2_test.go b/std/algebra/sw_bls24315/g2_test.go index af80416be3..6d5bf22e86 100644 --- a/std/algebra/sw_bls24315/g2_test.go +++ b/std/algebra/sw_bls24315/g2_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls24315" "github.com/consensys/gnark/test" @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) } }) diff --git a/std/algebra/sw_bls24315/pairing_test.go b/std/algebra/sw_bls24315/pairing_test.go index a858c7c19b..6b53f37178 100644 --- a/std/algebra/sw_bls24315/pairing_test.go +++ b/std/algebra/sw_bls24315/pairing_test.go @@ -25,6 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls24315" "github.com/consensys/gnark/test" ) @@ -213,6 +214,6 @@ func TestTriplePairingBLS24315(t *testing.T) { func BenchmarkPairing(b *testing.B) { var c pairingBLS24315 - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index ff879c2fa2..c02ec8f7d7 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards/bandersnatch" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/test" ) @@ -364,36 +364,36 @@ func TestNeg(t *testing.T) { // Bench func BenchmarkDouble(b *testing.B) { var c double - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddGeneric(b *testing.B) { var c addGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddFixedPoint(b *testing.B) { var c add - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkMustBeOnCurve(b *testing.B) { var c mustBeOnCurve - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulGeneric(b *testing.B) { var c scalarMulGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulFixed(b *testing.B) { var c scalarMulFixed - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/fiat-shamir/transcript_test.go b/std/fiat-shamir/transcript_test.go index ab728cb05c..b42b7bac21 100644 --- a/std/fiat-shamir/transcript_test.go +++ b/std/fiat-shamir/transcript_test.go @@ -23,8 +23,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) @@ -156,7 +156,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BN254, backend.PLONK, &circuit) + ccs, _ = frontend.Compile(ecc.BN254, scs.NewBuilder, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index fc571f7115..75eda2035e 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -24,6 +24,7 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" "github.com/consensys/gnark/internal/backend/bls12-377/witness" @@ -61,7 +62,7 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS12_377, backend.GROTH16, &circuit) + r1cs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, &circuit) if err != nil { t.Fatal(err) } @@ -200,7 +201,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 95a71be3ec..0f01fb94fd 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -24,6 +24,7 @@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/cs" groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" "github.com/consensys/gnark/internal/backend/bls24-315/witness" @@ -62,7 +63,7 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS24_315, backend.GROTH16, &circuit) + r1cs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, &circuit) if err != nil { t.Fatal(err) } @@ -201,7 +202,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/test/assert.go b/test/assert.go index 84b4efab2d..2926d9ca9b 100644 --- a/test/assert.go +++ b/test/assert.go @@ -30,6 +30,8 @@ import ( "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/stretchr/testify/require" ) @@ -397,14 +399,26 @@ func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendI // TODO we may want to check that it was compiled with the same compile options here return ccs, nil } + + var newBuilder frontend.NewBuilder + + switch backendID { + case backend.GROTH16: + newBuilder = r1cs.NewBuilder + case backend.PLONK: + newBuilder = scs.NewBuilder + default: + panic("not implemented") + } + // else compile it and ensure it is deterministic - ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) + ccs, err := frontend.Compile(curveID, newBuilder, circuit, compileOpts...) // ccs, err := compiler.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, err } - _ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) + _ccs, err := frontend.Compile(curveID, newBuilder, circuit, compileOpts...) // _ccs, err := compiler.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, fmt.Errorf("%w: %v", ErrCompilationNotDeterministic, err) From 96eedb5d3b6027c9088b02615e7a4cf6ad6a7ec1 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 25 Feb 2022 11:28:23 -0600 Subject: [PATCH 08/20] refactor: preparing frontend.Compiler interface --- README.md | 2 +- circuitstats_test.go | 8 +- debug_test.go | 4 +- examples/plonk/main.go | 2 +- examples/serialization/main.go | 2 +- frontend/compile.go | 91 ++++++++++--------- frontend/cs/r1cs/conversion.go | 9 +- frontend/cs/r1cs/r1cs.go | 20 ++-- frontend/cs/scs/conversion.go | 9 +- frontend/cs/scs/sparse_r1cs.go | 20 ++-- internal/backend/bls12-377/cs/r1cs_test.go | 6 +- .../backend/bls12-377/groth16/groth16_test.go | 2 +- .../backend/bls12-377/plonk/plonk_test.go | 2 +- internal/backend/bls12-381/cs/r1cs_test.go | 6 +- .../backend/bls12-381/groth16/groth16_test.go | 2 +- .../backend/bls12-381/plonk/plonk_test.go | 2 +- internal/backend/bls24-315/cs/r1cs_test.go | 6 +- .../backend/bls24-315/groth16/groth16_test.go | 2 +- .../backend/bls24-315/plonk/plonk_test.go | 2 +- internal/backend/bn254/cs/r1cs_test.go | 6 +- .../backend/bn254/groth16/groth16_test.go | 2 +- internal/backend/bn254/plonk/plonk_test.go | 2 +- internal/backend/bw6-633/cs/r1cs_test.go | 6 +- .../backend/bw6-633/groth16/groth16_test.go | 2 +- internal/backend/bw6-633/plonk/plonk_test.go | 2 +- internal/backend/bw6-761/cs/r1cs_test.go | 6 +- .../backend/bw6-761/groth16/groth16_test.go | 2 +- internal/backend/bw6-761/plonk/plonk_test.go | 2 +- .../representations/tests/r1cs.go.tmpl | 6 +- .../zkpschemes/groth16/tests/groth16.go.tmpl | 2 +- .../zkpschemes/plonk/tests/plonk.go.tmpl | 2 +- std/algebra/fields_bls24315/e24_test.go | 10 +- std/algebra/sw_bls12377/g1_test.go | 8 +- std/algebra/sw_bls12377/g2_test.go | 6 +- std/algebra/sw_bls12377/pairing_test.go | 2 +- std/algebra/sw_bls24315/g1_test.go | 8 +- std/algebra/sw_bls24315/g2_test.go | 6 +- std/algebra/sw_bls24315/pairing_test.go | 2 +- .../twistededwards/bandersnatch/point_test.go | 12 +-- std/fiat-shamir/transcript_test.go | 2 +- std/groth16_bls12377/verifier_test.go | 4 +- std/groth16_bls24315/verifier_test.go | 4 +- test/assert.go | 6 +- 43 files changed, 168 insertions(+), 139 deletions(-) diff --git a/README.md b/README.md index 1f20c8b16d..d4fcf03e5b 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ func (circuit *CubicCircuit) Define(api frontend.API) error { // compiles our circuit into a R1CS var circuit CubicCircuit -ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &circuit) +ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit) // groth16 zkSNARK: Setup pk, vk, err := groth16.Setup(ccs) diff --git a/circuitstats_test.go b/circuitstats_test.go index 12acf6da96..c5bd5b73a7 100644 --- a/circuitstats_test.go +++ b/circuitstats_test.go @@ -33,18 +33,18 @@ func TestCircuitStatistics(t *testing.T) { // copy the circuit now in case assert calls t.Parallel() tData := circuits.Circuits[k] assert.Run(func(assert *test.Assert) { - var newBuilder frontend.NewBuilder + var newCompiler frontend.NewCompiler switch backendID { case backend.GROTH16: - newBuilder = r1cs.NewBuilder + newCompiler = r1cs.NewCompiler case backend.PLONK: - newBuilder = scs.NewBuilder + newCompiler = scs.NewCompiler default: panic("not implemented") } - ccs, err := frontend.Compile(curve, newBuilder, tData.Circuit) + ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit) assert.NoError(err) // ensure we didn't introduce regressions that make circuits less efficient diff --git a/debug_test.go b/debug_test.go index b28a2959ff..126cc7abfa 100644 --- a/debug_test.go +++ b/debug_test.go @@ -174,7 +174,7 @@ func TestTraceNotBoolean(t *testing.T) { } func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, scs.NewBuilder, circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, circuit) if err != nil { return "", err } @@ -198,7 +198,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { } func getGroth16Trace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, circuit) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, circuit) if err != nil { return "", err } diff --git a/examples/plonk/main.go b/examples/plonk/main.go index 99563f7a9a..31b809eff4 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -73,7 +73,7 @@ func main() { var circuit Circuit // // building the circuit... - ccs, err := frontend.Compile(ecc.BN254, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, &circuit) if err != nil { fmt.Println("circuit compilation error") } diff --git a/examples/serialization/main.go b/examples/serialization/main.go index 538784ca43..ca0d7b11cd 100644 --- a/examples/serialization/main.go +++ b/examples/serialization/main.go @@ -17,7 +17,7 @@ func main() { var circuit cubic.Circuit // compile a circuit - _r1cs, _ := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &circuit) + _r1cs, _ := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit) // R1CS implements io.WriterTo and io.ReaderFrom var buf bytes.Buffer diff --git a/frontend/compile.go b/frontend/compile.go index 2af91d41ba..03d6434083 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -10,23 +10,27 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -var tVariable reflect.Type +// Compiler represents a constraint system compiler +type Compiler interface { + // a compiler must implement frontend.API and will be injected in circuit.Define() + API -func init() { - tVariable = reflect.ValueOf(struct{ A Variable }{}).FieldByName("A").Type() -} + // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) + Compile(opt CompileConfig) (CompiledConstraintSystem, error) -// Builder represents a constraint system builder -type Builder interface { - API - CheckVariables() error - NewPublicVariable(name string) Variable - NewSecretVariable(name string) Variable - Compile() (CompiledConstraintSystem, error) + // SetSchema is used internally by frontend.Compile to set the circuit schema SetSchema(*schema.Schema) + + // AddPublicVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddPublicVariable(name string) Variable + + // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddSecretVariable(name string) Variable } -type NewBuilder func(ecc.ID) (Builder, error) +type NewCompiler func(ecc.ID) (Compiler, error) // Compile will generate a ConstraintSystem from the given circuit // @@ -46,40 +50,33 @@ type NewBuilder func(ecc.ID) (Builder, error) // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. -func Compile(curveID ecc.ID, newBuilder NewBuilder, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { - // setup option - opt := compileConfig{} +func Compile(curveID ecc.ID, newCompiler NewCompiler, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { + // parse options + opt := CompileConfig{} for _, o := range opts { if err := o(&opt); err != nil { return nil, fmt.Errorf("apply option: %w", err) } } - builder, err := newBuilder(curveID) + // instantiate new compiler + compiler, err := newCompiler(curveID) if err != nil { - return nil, fmt.Errorf("new builder: %w", err) + return nil, fmt.Errorf("new compiler: %w", err) } - if err = bootstrap(builder, circuit); err != nil { - return nil, fmt.Errorf("bootstrap: %w", err) + // parse the circuit builds a schema of the circuit + // and call circuit.Define() method to initialize a list of constraints in the compiler + if err = parseCircuit(compiler, circuit); err != nil { + return nil, fmt.Errorf("parse circuit: %w", err) } - // ensure all inputs and hints are constrained - if !opt.ignoreUnconstrainedInputs { - if err := builder.CheckVariables(); err != nil { - return nil, err - } - } - - ccs, err := builder.Compile() - if err != nil { - return nil, fmt.Errorf("compile system: %w", err) - } - return ccs, nil + // compile the circuit into its final form + return compiler.Compile(opt) } -func bootstrap(builder Builder, circuit Circuit) (err error) { +func parseCircuit(compiler Compiler, circuit Circuit) (err error) { // ensure circuit.Define has pointer receiver if reflect.ValueOf(circuit).Kind() != reflect.Ptr { return errors.New("frontend.Circuit methods must be defined on pointer receiver") @@ -91,9 +88,9 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { if tInput.CanSet() { switch visibility { case schema.Secret: - tInput.Set(reflect.ValueOf(builder.NewSecretVariable(name))) + tInput.Set(reflect.ValueOf(compiler.AddSecretVariable(name))) case schema.Public: - tInput.Set(reflect.ValueOf(builder.NewPublicVariable(name))) + tInput.Set(reflect.ValueOf(compiler.AddPublicVariable(name))) case schema.Unset: return errors.New("can't set val " + name + " visibility is unset") } @@ -108,7 +105,7 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { if err != nil { return err } - builder.SetSchema(s) + compiler.SetSchema(s) // recover from panics to print user-friendlier messages defer func() { @@ -118,7 +115,7 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { }() // call Define() to fill in the Constraints - if err = circuit.Define(builder); err != nil { + if err = circuit.Define(compiler); err != nil { return fmt.Errorf("define circuit: %w", err) } @@ -128,19 +125,19 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { // CompileOption defines option for altering the behaviour of the Compile // method. See the descriptions of the functions returning instances of this // type for available options. -type CompileOption func(opt *compileConfig) error +type CompileOption func(opt *CompileConfig) error -type compileConfig struct { - capacity int - ignoreUnconstrainedInputs bool +type CompileConfig struct { + Capacity int + IgnoreUnconstrainedInputs bool } // WithCapacity is a compile option that specifies the estimated capacity needed // for internal variables and constraints. If not set, then the initial capacity // is 0 and is dynamically allocated as needed. func WithCapacity(capacity int) CompileOption { - return func(opt *compileConfig) error { - opt.capacity = capacity + return func(opt *CompileConfig) error { + opt.Capacity = capacity return nil } } @@ -153,8 +150,14 @@ func WithCapacity(capacity int) CompileOption { // production settings as it means that there is a potential error in the // circuit definition or that it is possible to optimize witness size. func IgnoreUnconstrainedInputs() CompileOption { - return func(opt *compileConfig) error { - opt.ignoreUnconstrainedInputs = true + return func(opt *CompileConfig) error { + opt.IgnoreUnconstrainedInputs = true return nil } } + +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A Variable }{}).FieldByName("A").Type() +} diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index 3d9be7da55..424251f94b 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -30,7 +30,14 @@ import ( ) // Compile constructs a rank-1 constraint sytem -func (cs *r1CS) Compile() (frontend.CompiledConstraintSystem, error) { +func (cs *r1CS) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !opt.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } // wires = public wires | secret wires | internal wires diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index ab82970581..23bd376f08 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -31,7 +31,7 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -func NewBuilder(curve ecc.ID) (frontend.Builder, error) { +func NewCompiler(curve ecc.ID) (frontend.Compiler, error) { return newR1CS(curve), nil } @@ -56,7 +56,7 @@ func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { MHints: make(map[int]*compiled.Hint), }, Constraints: make([]compiled.R1C, 0, capacity), - builder: cs.NewBuilder(), + builder: cs.NewCompiler(), } system.builder.Coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -92,8 +92,11 @@ func (system *r1CS) newInternalVariable() compiled.Variable { } } -// NewPublicVariable creates a new public Variable -func (system *r1CS) NewPublicVariable(name string) frontend.Variable { +// AddPublicVariable creates a new public Variable +func (system *r1CS) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } t := false idx := len(system.Public) system.Public = append(system.Public, name) @@ -104,8 +107,11 @@ func (system *r1CS) NewPublicVariable(name string) frontend.Variable { return res } -// NewSecretVariable creates a new secret Variable -func (system *r1CS) NewSecretVariable(name string) frontend.Variable { +// AddSecretVariable creates a new secret Variable +func (system *r1CS) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } t := false idx := len(system.Secret) system.Secret = append(system.Secret, name) @@ -210,7 +216,7 @@ func (system *r1CS) markBoolean(v compiled.Variable) bool { // // 1. checks that all user inputs are referenced in at least one constraint // 2. checks that all hints are constrained -func (system *r1CS) CheckVariables() error { +func (system *r1CS) checkVariables() error { // TODO @gbotrel add unit test for that. diff --git a/frontend/cs/scs/conversion.go b/frontend/cs/scs/conversion.go index f5bb6bde28..a0125afd79 100644 --- a/frontend/cs/scs/conversion.go +++ b/frontend/cs/scs/conversion.go @@ -29,7 +29,14 @@ import ( bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" ) -func (cs *sparseR1CS) Compile() (frontend.CompiledConstraintSystem, error) { +func (cs *sparseR1CS) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !opt.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } res := compiled.SparseR1CS{ ConstraintSystem: cs.ConstraintSystem, diff --git a/frontend/cs/scs/sparse_r1cs.go b/frontend/cs/scs/sparse_r1cs.go index 0afb18aa83..f8322eeb92 100644 --- a/frontend/cs/scs/sparse_r1cs.go +++ b/frontend/cs/scs/sparse_r1cs.go @@ -31,7 +31,7 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -func NewBuilder(curve ecc.ID) (frontend.Builder, error) { +func NewCompiler(curve ecc.ID) (frontend.Compiler, error) { return newSparseR1CS(curve), nil } @@ -56,7 +56,7 @@ func newSparseR1CS(curveID ecc.ID, initialCapacity ...int) *sparseR1CS { MHints: make(map[int]*compiled.Hint), }, Constraints: make([]compiled.SparseR1C, 0, capacity), - builder: cs.NewBuilder(), + builder: cs.NewCompiler(), } system.builder.Coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -109,15 +109,21 @@ func (system *sparseR1CS) newInternalVariable() compiled.Term { return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) } -// NewPublicVariable creates a new Public Variable -func (system *sparseR1CS) NewPublicVariable(name string) frontend.Variable { +// AddPublicVariable creates a new Public Variable +func (system *sparseR1CS) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } idx := len(system.Public) system.Public = append(system.Public, name) return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public) } -// NewPublicVariable creates a new Secret Variable -func (system *sparseR1CS) NewSecretVariable(name string) frontend.Variable { +// AddSecretVariable creates a new Secret Variable +func (system *sparseR1CS) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } idx := len(system.Secret) system.Secret = append(system.Secret, name) return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret) @@ -170,7 +176,7 @@ func (system *sparseR1CS) markBoolean(t compiled.Term) { // // 1. checks that all user inputs are referenced in at least one constraint // 2. checks that all hints are constrained -func (system *sparseR1CS) CheckVariables() error { +func (system *sparseR1CS) checkVariables() error { // TODO @gbotrel add unit test for that. diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index c8b2113d59..8af814bc99 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -36,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -45,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go index 9d377c713f..ebcfb09f76 100644 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ b/internal/backend/bls12-377/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 192ecc0324..2076347fd5 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index 57f05cc18b..71ba350ad0 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -36,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -45,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-381/groth16/groth16_test.go b/internal/backend/bls12-381/groth16/groth16_test.go index 3bf5084315..1c229f9ac8 100644 --- a/internal/backend/bls12-381/groth16/groth16_test.go +++ b/internal/backend/bls12-381/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/plonk/plonk_test.go b/internal/backend/bls12-381/plonk/plonk_test.go index c653d99b0f..712add07b1 100644 --- a/internal/backend/bls12-381/plonk/plonk_test.go +++ b/internal/backend/bls12-381/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 85ea52bff8..1fb373ed46 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -36,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -45,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls24-315/groth16/groth16_test.go b/internal/backend/bls24-315/groth16/groth16_test.go index de108512a5..ed67ed605d 100644 --- a/internal/backend/bls24-315/groth16/groth16_test.go +++ b/internal/backend/bls24-315/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/plonk/plonk_test.go b/internal/backend/bls24-315/plonk/plonk_test.go index 44dcbfe868..dcc755d9d4 100644 --- a/internal/backend/bls24-315/plonk/plonk_test.go +++ b/internal/backend/bls24-315/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index 2cbd8fb574..7c9806da35 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -36,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -45,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BN254, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bn254/groth16/groth16_test.go b/internal/backend/bn254/groth16/groth16_test.go index fc4735fe87..1fc0674145 100644 --- a/internal/backend/bn254/groth16/groth16_test.go +++ b/internal/backend/bn254/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/plonk/plonk_test.go b/internal/backend/bn254/plonk/plonk_test.go index 67bc45b90c..5860b7b3a7 100644 --- a/internal/backend/bn254/plonk/plonk_test.go +++ b/internal/backend/bn254/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/cs/r1cs_test.go b/internal/backend/bw6-633/cs/r1cs_test.go index 2f1ff61692..5287373c76 100644 --- a/internal/backend/bw6-633/cs/r1cs_test.go +++ b/internal/backend/bw6-633/cs/r1cs_test.go @@ -36,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -45,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-633/groth16/groth16_test.go b/internal/backend/bw6-633/groth16/groth16_test.go index 188d648bdf..913eefffec 100644 --- a/internal/backend/bw6-633/groth16/groth16_test.go +++ b/internal/backend/bw6-633/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/plonk/plonk_test.go b/internal/backend/bw6-633/plonk/plonk_test.go index 93c9b44350..b534add524 100644 --- a/internal/backend/bw6-633/plonk/plonk_test.go +++ b/internal/backend/bw6-633/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 71763a400c..b17a3f0f16 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -40,7 +40,7 @@ func TestSerialization(t *testing.T) { return } - r1cs1, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -49,7 +49,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -138,7 +138,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-761/groth16/groth16_test.go b/internal/backend/bw6-761/groth16/groth16_test.go index eebefa3488..2077099253 100644 --- a/internal/backend/bw6-761/groth16/groth16_test.go +++ b/internal/backend/bw6-761/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/plonk/plonk_test.go b/internal/backend/bw6-761/plonk/plonk_test.go index f44eb8982b..7e9b4fafd9 100644 --- a/internal/backend/bw6-761/plonk/plonk_test.go +++ b/internal/backend/bw6-761/plonk/plonk_test.go @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index 2cce6f0613..cb399f03e5 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -24,7 +24,7 @@ func TestSerialization(t *testing.T) { } {{end}} - r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, r1cs.NewBuilder, tc.Circuit) + r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -33,7 +33,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewBuilder, tc.Circuit) + r1cs2, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -123,7 +123,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewBuilder, &c) + ccs, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl index c6d4305a4b..4b0abd9e7c 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl @@ -38,7 +38,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID,r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(curve.ID,r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl index 140e5854f4..d53aff5dd5 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl @@ -43,7 +43,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, scs.NewBuilder, &circuit) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/std/algebra/fields_bls24315/e24_test.go b/std/algebra/fields_bls24315/e24_test.go index 1fc7c69764..38ce758adb 100644 --- a/std/algebra/fields_bls24315/e24_test.go +++ b/std/algebra/fields_bls24315/e24_test.go @@ -414,7 +414,7 @@ func BenchmarkMulE24(b *testing.B) { var c fp24Mul b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -425,7 +425,7 @@ func BenchmarkSquareE24(b *testing.B) { var c fp24Square b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -436,7 +436,7 @@ func BenchmarkInverseE24(b *testing.B) { var c fp24Inverse b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -447,7 +447,7 @@ func BenchmarkConjugateE24(b *testing.B) { var c fp24Conjugate b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -458,7 +458,7 @@ func BenchmarkMulBy034E24(b *testing.B) { var c fp24MulBy034 b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls12377/g1_test.go b/std/algebra/sw_bls12377/g1_test.go index 8736284061..878aa9f3bb 100644 --- a/std/algebra/sw_bls12377/g1_test.go +++ b/std/algebra/sw_bls12377/g1_test.go @@ -392,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -400,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } @@ -421,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -429,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls12377/g2_test.go b/std/algebra/sw_bls12377/g2_test.go index fd87276f8e..46cb6ed6d9 100644 --- a/std/algebra/sw_bls12377/g2_test.go +++ b/std/algebra/sw_bls12377/g2_test.go @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls12377/pairing_test.go b/std/algebra/sw_bls12377/pairing_test.go index 645a5e0307..9b88ff0bd8 100644 --- a/std/algebra/sw_bls12377/pairing_test.go +++ b/std/algebra/sw_bls12377/pairing_test.go @@ -200,7 +200,7 @@ func BenchmarkPairing(b *testing.B) { var c pairingBLS377 b.ResetTimer() for i := 0; i < b.N; i++ { - frontend.Compile(ecc.BW6_761, scs.NewBuilder, &c) + frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) } // ccsBench, _ = compiler.Compile(ecc.BW6_761, backend.GROTH16, &c) // b.Log("groth16", ccsBench.GetNbConstraints()) diff --git a/std/algebra/sw_bls24315/g1_test.go b/std/algebra/sw_bls24315/g1_test.go index c2a4ef314b..577865a75e 100644 --- a/std/algebra/sw_bls24315/g1_test.go +++ b/std/algebra/sw_bls24315/g1_test.go @@ -392,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -400,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewBuilder, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } @@ -421,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -429,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewBuilder, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls24315/g2_test.go b/std/algebra/sw_bls24315/g2_test.go index 6d5bf22e86..4fb2c941bb 100644 --- a/std/algebra/sw_bls24315/g2_test.go +++ b/std/algebra/sw_bls24315/g2_test.go @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls24315/pairing_test.go b/std/algebra/sw_bls24315/pairing_test.go index 6b53f37178..ab3d16fd97 100644 --- a/std/algebra/sw_bls24315/pairing_test.go +++ b/std/algebra/sw_bls24315/pairing_test.go @@ -214,6 +214,6 @@ func TestTriplePairingBLS24315(t *testing.T) { func BenchmarkPairing(b *testing.B) { var c pairingBLS24315 - ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index c02ec8f7d7..94af797d4e 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -364,36 +364,36 @@ func TestNeg(t *testing.T) { // Bench func BenchmarkDouble(b *testing.B) { var c double - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddGeneric(b *testing.B) { var c addGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddFixedPoint(b *testing.B) { var c add - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkMustBeOnCurve(b *testing.B) { var c mustBeOnCurve - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulGeneric(b *testing.B) { var c scalarMulGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulFixed(b *testing.B) { var c scalarMulFixed - ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewBuilder, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/fiat-shamir/transcript_test.go b/std/fiat-shamir/transcript_test.go index b42b7bac21..672b202731 100644 --- a/std/fiat-shamir/transcript_test.go +++ b/std/fiat-shamir/transcript_test.go @@ -156,7 +156,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BN254, scs.NewBuilder, &circuit) + ccs, _ = frontend.Compile(ecc.BN254, scs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index 75eda2035e..f90dc0974a 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -62,7 +62,7 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, &circuit) if err != nil { t.Fatal(err) } @@ -201,7 +201,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_761, r1cs.NewBuilder, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 0f01fb94fd..dff6d1902b 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -63,7 +63,7 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewBuilder, &circuit) + r1cs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, &circuit) if err != nil { t.Fatal(err) } @@ -202,7 +202,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_633, r1cs.NewBuilder, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/test/assert.go b/test/assert.go index 2926d9ca9b..aa0f5d506e 100644 --- a/test/assert.go +++ b/test/assert.go @@ -400,13 +400,13 @@ func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendI return ccs, nil } - var newBuilder frontend.NewBuilder + var newBuilder frontend.NewCompiler switch backendID { case backend.GROTH16: - newBuilder = r1cs.NewBuilder + newBuilder = r1cs.NewCompiler case backend.PLONK: - newBuilder = scs.NewBuilder + newBuilder = scs.NewCompiler default: panic("not implemented") } From e72a6766f3c9099533c3384326498af81b0206d2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 25 Feb 2022 12:55:15 -0600 Subject: [PATCH 09/20] checkpoint --- frontend/compiled/symbol.go | 6 +++ frontend/compiled/variable.go | 2 +- frontend/cs/{builder.go => coeff_table.go} | 34 ++++++++--------- frontend/cs/r1cs/api.go | 44 +++++++++++----------- frontend/cs/r1cs/assertions.go | 4 +- frontend/cs/r1cs/conversion.go | 12 +++--- frontend/cs/r1cs/r1cs.go | 30 +++++++-------- frontend/cs/scs/api.go | 14 +++---- frontend/cs/scs/assertions.go | 8 ++-- frontend/cs/scs/conversion.go | 12 +++--- frontend/cs/scs/sparse_r1cs.go | 32 +++++++++------- 11 files changed, 102 insertions(+), 96 deletions(-) create mode 100644 frontend/compiled/symbol.go rename frontend/cs/{builder.go => coeff_table.go} (61%) diff --git a/frontend/compiled/symbol.go b/frontend/compiled/symbol.go new file mode 100644 index 0000000000..62675d97af --- /dev/null +++ b/frontend/compiled/symbol.go @@ -0,0 +1,6 @@ +package compiled + +type Symbol interface { + AssertIsSet() + IsConstant() bool +} diff --git a/frontend/compiled/variable.go b/frontend/compiled/variable.go index 41085ba9de..e86f8e5944 100644 --- a/frontend/compiled/variable.go +++ b/frontend/compiled/variable.go @@ -63,7 +63,7 @@ func (v Variable) AssertIsSet() { } // isConstant returns true if the variable is ONE_WIRE * coeff -func (v *Variable) IsConstant() bool { +func (v Variable) IsConstant() bool { if len(v.LinExp) != 1 { return false } diff --git a/frontend/cs/builder.go b/frontend/cs/coeff_table.go similarity index 61% rename from frontend/cs/builder.go rename to frontend/cs/coeff_table.go index 64a93832e6..dca31934af 100644 --- a/frontend/cs/builder.go +++ b/frontend/cs/coeff_table.go @@ -4,33 +4,29 @@ import ( "math/big" ) -// Builder helps build a constraint system but need not be serialized after compilation -type Builder struct { +// CoeffTable helps build a constraint system but need not be serialized after compilation +type CoeffTable struct { // Coefficients in the constraints Coeffs []big.Int // list of unique coefficients. CoeffsIDsLarge map[string]int // map to check existence of a coefficient (key = coeff.Bytes()) CoeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) - - // map for recording boolean constrained variables (to not constrain them twice) - MTBooleans map[int]struct{} } -func NewBuilder() Builder { - return Builder{ +func NewCoeffTable() CoeffTable { + return CoeffTable{ Coeffs: make([]big.Int, 4), CoeffsIDsLarge: make(map[string]int), CoeffsIDsInt64: make(map[int64]int, 4), - MTBooleans: make(map[int]struct{}), } } // CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to // the list of Coeffs and returns the corresponding entry -func (b *Builder) CoeffID(v *big.Int) int { +func (t *CoeffTable) CoeffID(v *big.Int) int { // if the coeff is a int64 we have a fast path. if v.IsInt64() { - return b.coeffID64(v.Int64()) + return t.coeffID64(v.Int64()) } // GobEncode is 3x faster than b.Text(16). Slightly slower than Bytes, but Bytes return the same @@ -39,28 +35,28 @@ func (b *Builder) CoeffID(v *big.Int) int { key := string(bKey) // if the coeff is already stored, fetch its ID from the cs.CoeffsIDs map - if idx, ok := b.CoeffsIDsLarge[key]; ok { + if idx, ok := t.CoeffsIDsLarge[key]; ok { return idx } // else add it in the cs.Coeffs map and update the cs.CoeffsIDs map var bCopy big.Int bCopy.Set(v) - resID := len(b.Coeffs) - b.Coeffs = append(b.Coeffs, bCopy) - b.CoeffsIDsLarge[key] = resID + resID := len(t.Coeffs) + t.Coeffs = append(t.Coeffs, bCopy) + t.CoeffsIDsLarge[key] = resID return resID } -func (b *Builder) coeffID64(v int64) int { - if resID, ok := b.CoeffsIDsInt64[v]; ok { +func (t *CoeffTable) coeffID64(v int64) int { + if resID, ok := t.CoeffsIDsInt64[v]; ok { return resID } else { var bCopy big.Int bCopy.SetInt64(v) - resID := len(b.Coeffs) - b.Coeffs = append(b.Coeffs, bCopy) - b.CoeffsIDsInt64[v] = resID + resID := len(t.Coeffs) + t.Coeffs = append(t.Coeffs, bCopy) + t.CoeffsIDsInt64[v] = resID return resID } } diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 4045e694ce..70e444d393 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -40,7 +40,7 @@ import ( func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable t := false @@ -58,7 +58,7 @@ func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front // Neg returns -i func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i) + vars, _ := system.ToSymbols(i) if vars[0].IsConstant() { n := system.constantValue(vars[0]) @@ -76,7 +76,7 @@ func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable t := false @@ -100,7 +100,7 @@ func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) front // Mul returns res = i1 * i2 * ... in func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars, _ := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) mul := func(v1, v2 compiled.Variable) compiled.Variable { @@ -156,10 +156,10 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl case compiled.CoeffIdTwo: newCoeff.Add(lambda, lambda) default: - coeff := system.builder.Coeffs[cID] + coeff := system.st.Coeffs[cID] newCoeff.Mul(&coeff, lambda) } - res.LinExp[i] = compiled.Pack(vID, system.builder.CoeffID(&newCoeff), visibility) + res.LinExp[i] = compiled.Pack(vID, system.st.CoeffID(&newCoeff), visibility) } t := false res.IsBoolean = &t @@ -167,7 +167,7 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl } func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i1, i2) + vars, _ := system.ToSymbols(i1, i2) v1 := vars[0] v2 := vars[1] @@ -199,7 +199,7 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { // Div returns res = i1 / i2 func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i1, i2) + vars, _ := system.ToSymbols(i1, i2) v1 := vars[0] v2 := vars[1] @@ -233,7 +233,7 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { // Inverse returns res = inverse(v) func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i1) + vars, _ := system.ToSymbols(i1) if vars[0].IsConstant() { // c := vars[0].constantValue(cs) @@ -274,7 +274,7 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable } } - vars, _ := system.toVariables(i1) + vars, _ := system.ToSymbols(i1) a := vars[0] // if a is a constant, work with the big int value. @@ -345,7 +345,7 @@ func toSliceOfVariables(v []compiled.Variable) []frontend.Variable { // FromBinary packs b, seen as a fr.Element in little endian func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { - b, _ := system.toVariables(_b...) + b, _ := system.ToSymbols(_b...) // ensure inputs are set for i := 0; i < len(b); i++ { @@ -374,7 +374,7 @@ func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { // Xor compute the XOR between two frontend.Variables func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(_a, _b) + vars, _ := system.ToSymbols(_a, _b) a := vars[0] b := vars[1] @@ -398,7 +398,7 @@ func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { // Or compute the OR between two frontend.Variables func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(_a, _b) + vars, _ := system.ToSymbols(_a, _b) a := vars[0] b := vars[1] @@ -421,7 +421,7 @@ func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { // And compute the AND between two frontend.Variables func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(_a, _b) + vars, _ := system.ToSymbols(_a, _b) a := vars[0] b := vars[1] @@ -440,7 +440,7 @@ func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { // Select if i0 is true, yields i1 else yields i2 func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i0, i1, i2) + vars, _ := system.ToSymbols(i0, i1, i2) b := vars[0] // ensures that b is boolean @@ -474,7 +474,7 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) + vars, _ := system.ToSymbols(b0, b1, i0, i1, i2, i3) s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -505,7 +505,7 @@ func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va // IsZero returns 1 if i1 is zero, 0 otherwise func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { - vars, _ := system.toVariables(i1) + vars, _ := system.ToSymbols(i1) a := vars[0] if a.IsConstant() { // c := a.constantValue(cs) @@ -540,7 +540,7 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 Date: Mon, 28 Feb 2022 12:14:26 -0600 Subject: [PATCH 10/20] refactor: remove IsBoolean from R1CS variables --- frontend/compile.go | 10 +++++ frontend/compiled/term.go | 10 +++++ frontend/compiled/variable.go | 4 +- frontend/cs/r1cs/api.go | 71 ++++++++++++-------------------- frontend/cs/r1cs/assertions.go | 30 +++++++------- frontend/cs/r1cs/r1cs.go | 74 ++++++++++++++++++++++++---------- frontend/cs/scs/assertions.go | 6 +-- frontend/cs/scs/sparse_r1cs.go | 23 ++++++++--- 8 files changed, 134 insertions(+), 94 deletions(-) diff --git a/frontend/compile.go b/frontend/compile.go index 03d6434083..f1111424f0 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -28,6 +28,16 @@ type Compiler interface { // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if // called inside circuit.Define() AddSecretVariable(name string) Variable + + // MarkBoolean sets (but do not constraint!) v to be boolean + // This is useful in scenarios where a variable is known to be boolean through a constraint + // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. + MarkBoolean(v Variable) + + // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) + // Use with care; variable may not have been **constrained** to be boolean + // This returns true if the v is a constant and v == 0 || v == 1. + IsBoolean(v Variable) bool } type NewCompiler func(ecc.ID) (Compiler, error) diff --git a/frontend/compiled/term.go b/frontend/compiled/term.go index 3d7b2d7380..1a84217e15 100644 --- a/frontend/compiled/term.go +++ b/frontend/compiled/term.go @@ -217,3 +217,13 @@ func (v LinearExpression) Less(i, j int) bool { } return iVis > jVis } + +// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear +// expression +func (v LinearExpression) HashCode() uint64 { + h := uint64(17) + for _, val := range v { + h = h*23 + uint64(val) + } + return h +} diff --git a/frontend/compiled/variable.go b/frontend/compiled/variable.go index e86f8e5944..cd284193bb 100644 --- a/frontend/compiled/variable.go +++ b/frontend/compiled/variable.go @@ -27,14 +27,12 @@ var errNoValue = errors.New("can't determine API input value") // Variable represent a linear expression of wires type Variable struct { - LinExp LinearExpression - IsBoolean *bool + LinExp LinearExpression } // Clone returns a copy of the underlying slice func (v Variable) Clone() Variable { var res Variable - res.IsBoolean = v.IsBoolean res.LinExp = make([]Term, len(v.LinExp)) copy(res.LinExp, v.LinExp) return res diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 70e444d393..5035c2484e 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -43,8 +43,7 @@ func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - t := false - res := compiled.Variable{LinExp: make([]compiled.Term, 0, s), IsBoolean: &t} + res := compiled.Variable{LinExp: make([]compiled.Term, 0, s)} for _, v := range vars { l := v.Clone() @@ -63,11 +62,10 @@ func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { if vars[0].IsConstant() { n := system.constantValue(vars[0]) n.Neg(n) - return system.constant(n) + return system.ToVariable(n) } - // ok to pass pointer since if i is boolean constrained later, so must be res - res := compiled.Variable{LinExp: system.negateLinExp(vars[0].LinExp), IsBoolean: vars[0].IsBoolean} + res := compiled.Variable{LinExp: system.negateLinExp(vars[0].LinExp)} return res } @@ -79,10 +77,8 @@ func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) front vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - t := false res := compiled.Variable{ - LinExp: make([]compiled.Term, 0, s), - IsBoolean: &t, + LinExp: make([]compiled.Term, 0, s), } c := vars[0].Clone() @@ -117,7 +113,7 @@ func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front b2 := system.constantValue(v2) b1.Mul(b1, b2).Mod(b1, system.CurveID.Info().Fr.Modulus()) - return system.constant(b1).(compiled.Variable) + return system.ToVariable(b1).(compiled.Variable) } // ensure v2 is the constant @@ -161,8 +157,6 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl } res.LinExp[i] = compiled.Pack(vID, system.st.CoeffID(&newCoeff), visibility) } - t := false - res.IsBoolean = &t return res } @@ -190,11 +184,11 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { if v1.IsConstant() { b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.constant(b2) + return system.ToVariable(b2) } // v1 is not constant - return system.mulConstant(v1, system.constant(b2).(compiled.Variable)) + return system.mulConstant(v1, system.ToVariable(b2).(compiled.Variable)) } // Div returns res = i1 / i2 @@ -224,11 +218,11 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { if v1.IsConstant() { b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.constant(b2) + return system.ToVariable(b2) } // v1 is not constant - return system.mulConstant(v1, system.constant(b2).(compiled.Variable)) + return system.mulConstant(v1, system.ToVariable(b2).(compiled.Variable)) } // Inverse returns res = inverse(v) @@ -243,7 +237,7 @@ func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { } c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) - return system.constant(c) + return system.ToVariable(c) } // allocate resulting frontend.Variable @@ -280,11 +274,11 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable // if a is a constant, work with the big int value. if a.IsConstant() { c := system.constantValue(a) - b := make([]compiled.Variable, nbBits) + b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { - b[i] = system.constant(c.Bit(i)).(compiled.Variable) + b[i] = system.ToVariable(c.Bit(i)) } - return toSliceOfVariables(b) + return b } return system.toBinary(a, nbBits, false) @@ -334,15 +328,6 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro } -func toSliceOfVariables(v []compiled.Variable) []frontend.Variable { - // TODO this is ugly. - r := make([]frontend.Variable, len(v)) - for i := 0; i < len(v); i++ { - r[i] = v[i] - } - return r -} - // FromBinary packs b, seen as a fr.Element in little endian func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { b, _ := system.ToSymbols(_b...) @@ -355,7 +340,7 @@ func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { // res = Σ (2**i * b[i]) var res, v frontend.Variable - res = system.constant(0) // no constraint is recorded + res = system.ToVariable(0) // no constraint is recorded var c big.Int c.SetUint64(1) @@ -384,11 +369,8 @@ func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() - res.IsBoolean = new(bool) - *res.IsBoolean = true + system.MarkBoolean(res) c := system.Neg(res).(compiled.Variable) - c.IsBoolean = new(bool) - *c.IsBoolean = false c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) aa := system.Mul(a, 2) system.Constraints = append(system.Constraints, newR1C(aa, b, c)) @@ -408,11 +390,8 @@ func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() - res.IsBoolean = new(bool) - *res.IsBoolean = true + system.MarkBoolean(res) c := system.Neg(res).(compiled.Variable) - c.IsBoolean = new(bool) - *c.IsBoolean = false c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) system.Constraints = append(system.Constraints, newR1C(a, b, c)) @@ -511,9 +490,9 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { // c := a.constantValue(cs) c := system.constantValue(a) if c.IsUint64() && c.Uint64() == 0 { - return system.constant(1) + return system.ToVariable(1) } - return system.constant(0) + return system.ToVariable(0) } debug := system.AddDebugInfo("isZero", a) @@ -529,7 +508,7 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { panic(err) } m := res[0] - system.addConstraint(newR1C(a, m, system.constant(0)), debug) + system.addConstraint(newR1C(a, m, system.ToVariable(0)), debug) system.AssertIsBoolean(m) ma := system.Add(m, a) @@ -544,7 +523,7 @@ func (system *r1CS) Cmp(i1, i2 frontend.Variable) frontend.Variable { bi1 := system.ToBinary(vars[0], system.BitLen()) bi2 := system.ToBinary(vars[1], system.BitLen()) - res := system.constant(0) + res := system.ToVariable(0) for i := system.BitLen() - 1; i >= 0; i-- { @@ -749,14 +728,14 @@ func (system *r1CS) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.V return res, nil } -// constant will return (and allocate if neccesary) a frontend.Variable from given value +// ToVariable will return (and allocate if neccesary) a frontend.Variable from given value // // if input is already a frontend.Variable, does nothing -// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a constant frontend.Variable +// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a ToVariable frontend.Variable // -// a constant frontend.Variable does NOT necessary allocate a frontend.Variable in the ConstraintSystem +// a ToVariable frontend.Variable does NOT necessary allocate a frontend.Variable in the ConstraintSystem // it is in the form ONE_WIRE * coeff -func (system *r1CS) constant(input frontend.Variable) frontend.Variable { +func (system *r1CS) ToVariable(input interface{}) frontend.Variable { switch t := input.(type) { case compiled.Variable: @@ -778,7 +757,7 @@ func (system *r1CS) ToSymbols(in ...frontend.Variable) ([]compiled.Variable, int r := make([]compiled.Variable, 0, len(in)) s := 0 e := func(i frontend.Variable) { - v := system.constant(i).(compiled.Variable) + v := system.ToVariable(i).(compiled.Variable) r = append(r, v) s += len(v.LinExp) } diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go index 5790c99529..c77e0e9016 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/assertions.go @@ -28,8 +28,8 @@ import ( // AssertIsEqual adds an assertion in the constraint system (i1 == i2) func (system *r1CS) AssertIsEqual(i1, i2 frontend.Variable) { // encoded 1 * i1 == i2 - r := system.constant(i1).(compiled.Variable) - o := system.constant(i2).(compiled.Variable) + r := system.ToVariable(i1).(compiled.Variable) + o := system.ToVariable(i2).(compiled.Variable) debug := system.AddDebugInfo("assertIsEqual", r, " == ", o) @@ -47,11 +47,6 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { vars, _ := system.ToSymbols(i1) v := vars[0] - if *v.IsBoolean { - return // compiled.Variable is already constrained - } - *v.IsBoolean = true - if v.IsConstant() { c := system.constantValue(v) if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { @@ -60,9 +55,14 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { return } + if system.IsBoolean(v) { + return // compiled.Variable is already constrained + } + system.MarkBoolean(v) + debug := system.AddDebugInfo("assertIsBoolean", v, " == (0|1)") - o := system.constant(0) + o := system.ToVariable(0) // ensure v * (1 - v) == 0 _v := system.Sub(1, v) @@ -97,9 +97,9 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { boundBits := system.ToBinary(bound, nbBits) p := make([]frontend.Variable, nbBits+1) - p[nbBits] = system.constant(1) + p[nbBits] = system.ToVariable(1) - zero := system.constant(0) + zero := system.ToVariable(0) for i := nbBits - 1; i >= 0; i-- { @@ -122,7 +122,7 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - system.markBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint + system.MarkBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint system.addConstraint(newR1C(l, aBits[i], zero), debug) } @@ -142,7 +142,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { } // debug info - debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.constant(bound)) + debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.ToVariable(bound)) // note that at this stage, we didn't boolean-constraint these new variables yet // (as opposed to ToBinary) @@ -159,7 +159,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { p := make([]frontend.Variable, nbBits+1) // p[i] == 1 → a[j] == c[j] for all j ⩾ i - p[nbBits] = system.constant(1) + p[nbBits] = system.ToVariable(1) for i := nbBits - 1; i >= t; i-- { if bound.Bit(i) == 0 { @@ -175,8 +175,8 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { l := system.Sub(1, p[i+1]) l = system.Sub(l, aBits[i]) - system.addConstraint(newR1C(l, aBits[i], system.constant(0)), debug) - system.markBoolean(aBits[i].(compiled.Variable)) + system.addConstraint(newR1C(l, aBits[i], system.ToVariable(0)), debug) + system.MarkBoolean(aBits[i].(compiled.Variable)) } else { system.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index dc76d6ff9d..29ae8a1f65 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -40,6 +40,9 @@ type r1CS struct { Constraints []compiled.R1C st cs.CoeffTable + + // map for recording boolean constrained variables (to not constrain them twice) + mtBooleans map[uint64][]compiled.LinearExpression } // initialCapacity has quite some impact on frontend performance, especially on large circuits size @@ -57,6 +60,7 @@ func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { }, Constraints: make([]compiled.R1C, 0, capacity), st: cs.NewCoeffTable(), + mtBooleans: make(map[uint64][]compiled.LinearExpression), } system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) @@ -83,12 +87,10 @@ func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets // the wire's id to the number of wires, and returns it func (system *r1CS) newInternalVariable() compiled.Variable { - t := false idx := system.NbInternalVariables system.NbInternalVariables++ return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, - IsBoolean: &t, + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, } } @@ -97,12 +99,10 @@ func (system *r1CS) AddPublicVariable(name string) frontend.Variable { if system.Schema != nil { panic("do not call AddPublicVariable in circuit.Define()") } - t := false idx := len(system.Public) system.Public = append(system.Public, name) res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, - IsBoolean: &t, + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, } return res } @@ -112,31 +112,27 @@ func (system *r1CS) AddSecretVariable(name string) frontend.Variable { if system.Schema != nil { panic("do not call AddSecretVariable in circuit.Define()") } - t := false idx := len(system.Secret) system.Secret = append(system.Secret, name) res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, - IsBoolean: &t, + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, } return res } // func (v *variable) constantValue(system *R1CS) *big.Int { -func (system *r1CS) constantValue(v compiled.Symbol) *big.Int { +func (system *r1CS) constantValue(v compiled.Variable) *big.Int { // TODO this might be a good place to start hunting useless allocations. // maybe through a big.Int pool. if !v.IsConstant() { panic("can't get big.Int value on a non-constant variable") } - return new(big.Int).Set(&system.st.Coeffs[v.(compiled.Variable).LinExp[0].CoeffID()]) + return new(big.Int).Set(&system.st.Coeffs[v.LinExp[0].CoeffID()]) } func (system *r1CS) one() compiled.Variable { - t := false return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, - IsBoolean: &t, + LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, } } @@ -201,15 +197,51 @@ func (system *r1CS) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { return compiled.Pack(vID, system.st.CoeffID(coeff), vVis) } -// markBoolean marks the Variable as boolean and return true -// if a constraint was added, false if the Variable was already -// constrained as a boolean -func (system *r1CS) markBoolean(v compiled.Variable) bool { - if *v.IsBoolean { +// MarkBoolean sets (but do not **constraint**!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *r1CS) MarkBoolean(v frontend.Variable) { + if system.IsConstant(v) { + return + } + // v is a linear expression + l := v.(compiled.Variable).LinExp + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list := system.mtBooleans[key] + list = append(list, l) + system.mtBooleans[key] = list +} + +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *r1CS) IsBoolean(v frontend.Variable) bool { + if system.IsConstant(v) { + b := system.ConstantValue(v) + return b.IsUint64() && b.Uint64() <= 1 + } + // v is a linear expression + l := v.(compiled.Variable).LinExp + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list, ok := system.mtBooleans[key] + if !ok { return false } - *v.IsBoolean = true - return true + + for _, v := range list { + if v.Equal(l) { + return true + } + } + return false } // checkVariables perform post compilation checks on the Variables diff --git a/frontend/cs/scs/assertions.go b/frontend/cs/scs/assertions.go index e1029ca1a2..1e27d65937 100644 --- a/frontend/cs/scs/assertions.go +++ b/frontend/cs/scs/assertions.go @@ -73,10 +73,10 @@ func (system *sparseR1CS) AssertIsBoolean(i1 frontend.Variable) { return } t := i1.(compiled.Term) - if system.isBoolean(t) { + if system.IsBoolean(t) { return } - system.markBoolean(t) + system.MarkBoolean(t) system.mtBooleans[int(t)] = struct{}{} debug := system.AddDebugInfo("assertIsBoolean", t, " == (0|1)") cID, _, _ := t.Unpack() @@ -127,7 +127,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint + system.MarkBoolean(aBits[i].(compiled.Term)) // this does not create a constraint system.addPlonkConstraint( l.(compiled.Term), diff --git a/frontend/cs/scs/sparse_r1cs.go b/frontend/cs/scs/sparse_r1cs.go index c0cf9afb77..24a2950cd3 100644 --- a/frontend/cs/scs/sparse_r1cs.go +++ b/frontend/cs/scs/sparse_r1cs.go @@ -165,15 +165,26 @@ func (system *sparseR1CS) zero() compiled.Term { return a } -// returns true if a variable is already boolean -func (system *sparseR1CS) isBoolean(t compiled.Term) bool { - _, ok := system.mtBooleans[int(t)] +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *sparseR1CS) IsBoolean(v frontend.Variable) bool { + if system.IsConstant(v) { + b := system.ConstantValue(v) + return b.IsUint64() && b.Uint64() <= 1 + } + _, ok := system.mtBooleans[int(v.(compiled.Term))] return ok } -// markBoolean records t in the map to not boolean constrain it twice -func (system *sparseR1CS) markBoolean(t compiled.Term) { - system.mtBooleans[int(t)] = struct{}{} +// MarkBoolean sets (but do not constraint!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *sparseR1CS) MarkBoolean(v frontend.Variable) { + if system.IsConstant(v) { + return + } + system.mtBooleans[int(v.(compiled.Term))] = struct{}{} } // checkVariables perform post compilation checks on the Variables From 9a1ae2dc232608b734e8b7fdb8c0e85f0a92084d Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 13:15:38 -0600 Subject: [PATCH 11/20] refactor: split compiler, api and builder interface into interfaces --- examples/rollup/circuit.go | 2 +- frontend/api.go | 47 +------- frontend/compile.go | 59 ++++++++-- frontend/compiled/variable.go | 11 -- frontend/cs/r1cs/api.go | 109 +++++++++--------- frontend/cs/r1cs/assertions.go | 3 +- frontend/cs/r1cs/conversion.go | 3 + frontend/cs/r1cs/r1cs.go | 17 +-- frontend/cs/scs/api.go | 24 ++-- frontend/cs/scs/conversion.go | 3 + frontend/cs/scs/sparse_r1cs.go | 5 +- internal/backend/circuits/hint.go | 6 +- std/algebra/sw_bls12377/g1.go | 10 +- std/algebra/sw_bls24315/g1.go | 10 +- .../twistededwards/bandersnatch/point_test.go | 12 +- std/algebra/twistededwards/point_test.go | 14 +-- std/fiat-shamir/transcript_test.go | 2 +- std/hash/mimc/mimc.go | 2 +- std/signature/eddsa/eddsa_test.go | 2 +- test/engine.go | 19 ++- test/engine_test.go | 8 +- 21 files changed, 184 insertions(+), 184 deletions(-) diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 83f554c2b9..8d15a115ae 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -88,7 +88,7 @@ type TransferConstraints struct { func (circuit *Circuit) postInit(api frontend.API) error { // edward curve params - params, err := twistededwards.NewEdCurve(api.Curve()) + params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/frontend/api.go b/frontend/api.go index d3b7f5e09e..7a184f4678 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -16,14 +16,6 @@ limitations under the License. package frontend -import ( - "math/big" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" -) - // API represents the available functions to circuit developers type API interface { // --------------------------------------------------------------------------------------------- @@ -112,41 +104,6 @@ type API interface { // whose value will be resolved at runtime when computed by the solver Println(a ...Variable) - // NewHint initializes internal variables whose value will be evaluated - // using the provided hint function at run time from the inputs. Inputs must - // be either variables or convertible to *big.Int. The function returns an - // error if the number of inputs is not compatible with f. - // - // The hint function is provided at the proof creation time and is not - // embedded into the circuit. From the backend point of view, the variable - // returned by the hint function is equivalent to the user-supplied witness, - // but its actual value is assigned by the solver, not the caller. - // - // No new constraints are added to the newly created wire and must be added - // manually in the circuit. Failing to do so leads to solver failure. - // - // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) - - // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to - // measure constraints, variables and coefficients creations through AddCounter - Tag(name string) Tag - - // AddCounter measures the number of constraints, variables and coefficients created between two tags - // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions - // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time - AddCounter(from, to Tag) - - // IsConstant returns true if v is a constant known at compile time - IsConstant(v Variable) bool - - // ConstantValue returns the big.Int value of v. It - // panics if v.IsConstant() == false - ConstantValue(v Variable) *big.Int - - // CurveID returns the ecc.ID injected by the compiler - Curve() ecc.ID - - // Backend returns the backend.ID injected by the compiler - Backend() backend.ID + // Compiler returns the compiler object for advanced circuit development + Compiler() Compiler } diff --git a/frontend/compile.go b/frontend/compile.go index f1111424f0..aad3cadb74 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -3,17 +3,19 @@ package frontend import ( "errors" "fmt" + "math/big" "reflect" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" ) -// Compiler represents a constraint system compiler -type Compiler interface { - // a compiler must implement frontend.API and will be injected in circuit.Define() +type Builder interface { API + Compiler // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) Compile(opt CompileConfig) (CompiledConstraintSystem, error) @@ -28,7 +30,10 @@ type Compiler interface { // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if // called inside circuit.Define() AddSecretVariable(name string) Variable +} +// Compiler represents a constraint system compiler +type Compiler interface { // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. @@ -38,9 +43,45 @@ type Compiler interface { // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. IsBoolean(v Variable) bool + + // NewHint initializes internal variables whose value will be evaluated + // using the provided hint function at run time from the inputs. Inputs must + // be either variables or convertible to *big.Int. The function returns an + // error if the number of inputs is not compatible with f. + // + // The hint function is provided at the proof creation time and is not + // embedded into the circuit. From the backend point of view, the variable + // returned by the hint function is equivalent to the user-supplied witness, + // but its actual value is assigned by the solver, not the caller. + // + // No new constraints are added to the newly created wire and must be added + // manually in the circuit. Failing to do so leads to solver failure. + // + // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + + // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to + // measure constraints, variables and coefficients creations through AddCounter + Tag(name string) Tag + + // AddCounter measures the number of constraints, variables and coefficients created between two tags + // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions + // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time + AddCounter(from, to Tag) + + // ConstantValue returns the big.Int value of v and true if op is a success. + // nil and false if failure. This API returns a boolean to allow for future refactoring + // replacing *big.Int with fr.Element + ConstantValue(v Variable) (*big.Int, bool) + + // CurveID returns the ecc.ID injected by the compiler + Curve() ecc.ID + + // Backend returns the backend.ID injected by the compiler + Backend() backend.ID } -type NewCompiler func(ecc.ID) (Compiler, error) +type NewCompiler func(ecc.ID) (Builder, error) // Compile will generate a ConstraintSystem from the given circuit // @@ -86,7 +127,7 @@ func Compile(curveID ecc.ID, newCompiler NewCompiler, circuit Circuit, opts ...C return compiler.Compile(opt) } -func parseCircuit(compiler Compiler, circuit Circuit) (err error) { +func parseCircuit(builder Builder, circuit Circuit) (err error) { // ensure circuit.Define has pointer receiver if reflect.ValueOf(circuit).Kind() != reflect.Ptr { return errors.New("frontend.Circuit methods must be defined on pointer receiver") @@ -98,9 +139,9 @@ func parseCircuit(compiler Compiler, circuit Circuit) (err error) { if tInput.CanSet() { switch visibility { case schema.Secret: - tInput.Set(reflect.ValueOf(compiler.AddSecretVariable(name))) + tInput.Set(reflect.ValueOf(builder.AddSecretVariable(name))) case schema.Public: - tInput.Set(reflect.ValueOf(compiler.AddPublicVariable(name))) + tInput.Set(reflect.ValueOf(builder.AddPublicVariable(name))) case schema.Unset: return errors.New("can't set val " + name + " visibility is unset") } @@ -115,7 +156,7 @@ func parseCircuit(compiler Compiler, circuit Circuit) (err error) { if err != nil { return err } - compiler.SetSchema(s) + builder.SetSchema(s) // recover from panics to print user-friendlier messages defer func() { @@ -125,7 +166,7 @@ func parseCircuit(compiler Compiler, circuit Circuit) (err error) { }() // call Define() to fill in the Constraints - if err = circuit.Define(compiler); err != nil { + if err = circuit.Define(builder); err != nil { return fmt.Errorf("define circuit: %w", err) } diff --git a/frontend/compiled/variable.go b/frontend/compiled/variable.go index cd284193bb..4c3e2200a3 100644 --- a/frontend/compiled/variable.go +++ b/frontend/compiled/variable.go @@ -18,8 +18,6 @@ import ( "errors" "math/big" "strings" - - "github.com/consensys/gnark/frontend/schema" ) // errNoValue triggered when trying to access a variable that was not allocated @@ -59,12 +57,3 @@ func (v Variable) AssertIsSet() { } } - -// isConstant returns true if the variable is ONE_WIRE * coeff -func (v Variable) IsConstant() bool { - if len(v.LinExp) != 1 { - return false - } - _, vID, visibility := v.LinExp[0].Unpack() - return vID == 0 && visibility == schema.Public -} diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 5035c2484e..3736998854 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -59,8 +59,7 @@ func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { vars, _ := system.ToSymbols(i) - if vars[0].IsConstant() { - n := system.constantValue(vars[0]) + if n, ok := system.ConstantValue(vars[0]); ok { n.Neg(n) return system.ToVariable(n) } @@ -100,24 +99,24 @@ func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front mul := func(v1, v2 compiled.Variable) compiled.Variable { + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) + // v1 and v2 are both unknown, this is the only case we add a constraint - if !v1.IsConstant() && !v2.IsConstant() { + if !v1Constant && !v2Constant { res := system.newInternalVariable() system.Constraints = append(system.Constraints, newR1C(v1, v2, res)) return res } // v1 and v2 are constants, we multiply big.Int values and return resulting constant - if v1.IsConstant() && v2.IsConstant() { - b1 := system.constantValue(v1) - b2 := system.constantValue(v2) - - b1.Mul(b1, b2).Mod(b1, system.CurveID.Info().Fr.Modulus()) - return system.ToVariable(b1).(compiled.Variable) + if v1Constant && v2Constant { + n1.Mul(n1, n2).Mod(n1, system.CurveID.Info().Fr.Modulus()) + return system.ToVariable(n1).(compiled.Variable) } // ensure v2 is the constant - if v1.IsConstant() { + if v1Constant { v1, v2 = v2, v1 } @@ -137,7 +136,7 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable res := v1.Clone() - lambda := system.constantValue(constant) + lambda, _ := system.ConstantValue(constant) for i, t := range v1.LinExp { cID, vID, visibility := t.Unpack() @@ -166,7 +165,10 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { v1 := vars[0] v2 := vars[1] - if !v2.IsConstant() { + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) + + if !v2Constant { res := system.newInternalVariable() debug := system.AddDebugInfo("div", v1, "/", v2, " == ", res) // note that here we don't ensure that divisor is != 0 @@ -175,20 +177,19 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { } // v2 is constant - b2 := system.constantValue(v2) - if b2.IsUint64() && b2.Uint64() == 0 { + if n2.IsUint64() && n2.Uint64() == 0 { panic("div by constant(0)") } q := system.CurveID.Info().Fr.Modulus() - b2.ModInverse(b2, q) + n2.ModInverse(n2, q) - if v1.IsConstant() { - b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.ToVariable(b2) + if v1Constant { + n2.Mul(n2, n1).Mod(n2, q) + return system.ToVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.ToVariable(b2).(compiled.Variable)) + return system.mulConstant(v1, system.ToVariable(n2).(compiled.Variable)) } // Div returns res = i1 / i2 @@ -198,7 +199,10 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { v1 := vars[0] v2 := vars[1] - if !v2.IsConstant() { + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) + + if !v2Constant { res := system.newInternalVariable() debug := system.AddDebugInfo("div", v1, "/", v2, " == ", res) v2Inv := system.newInternalVariable() @@ -209,29 +213,26 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { } // v2 is constant - b2 := system.constantValue(v2) - if b2.IsUint64() && b2.Uint64() == 0 { + if n2.IsUint64() && n2.Uint64() == 0 { panic("div by constant(0)") } q := system.CurveID.Info().Fr.Modulus() - b2.ModInverse(b2, q) + n2.ModInverse(n2, q) - if v1.IsConstant() { - b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.ToVariable(b2) + if v1Constant { + n2.Mul(n2, n1).Mod(n2, q) + return system.ToVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.ToVariable(b2).(compiled.Variable)) + return system.mulConstant(v1, system.ToVariable(n2).(compiled.Variable)) } // Inverse returns res = inverse(v) func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { vars, _ := system.ToSymbols(i1) - if vars[0].IsConstant() { - // c := vars[0].constantValue(cs) - c := system.constantValue(vars[0]) + if c, ok := system.ConstantValue(vars[0]); ok { if c.IsUint64() && c.Uint64() == 0 { panic("inverse by constant(0)") } @@ -272,8 +273,7 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable a := vars[0] // if a is a constant, work with the big int value. - if a.IsConstant() { - c := system.constantValue(a) + if c, ok := system.ConstantValue(a); ok { b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { b[i] = system.ToVariable(c.Bit(i)) @@ -287,7 +287,7 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable // toBinary is equivalent to ToBinary, exept the returned bits are NOT boolean constrained. func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []frontend.Variable { - if a.IsConstant() { + if _, ok := system.ConstantValue(a); ok { return system.ToBinary(a, nbBits) } @@ -424,10 +424,10 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // ensures that b is boolean system.AssertIsBoolean(b) + n1, ok1 := system.ConstantValue(vars[1]) + n2, ok2 := system.ConstantValue(vars[2]) - if vars[1].IsConstant() && vars[2].IsConstant() { - n1 := system.constantValue(vars[1]) - n2 := system.constantValue(vars[2]) + if ok1 && ok2 { diff := n1.Sub(n1, n2) res := system.Mul(b, diff) // no constraint is recorded res = system.Add(res, vars[2]) // no constraint is recorded @@ -435,8 +435,7 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { } // special case appearing in AssertIsLessOrEq - if vars[1].IsConstant() { - n1 := system.constantValue(vars[1]) + if ok1 { if n1.IsUint64() && n1.Uint64() == 0 { v := system.Sub(1, vars[0]) return system.Mul(v, vars[2]) @@ -486,9 +485,7 @@ func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { vars, _ := system.ToSymbols(i1) a := vars[0] - if a.IsConstant() { - // c := a.constantValue(cs) - c := system.constantValue(a) + if c, ok := system.ConstantValue(a); ok { if c.IsUint64() && c.Uint64() == 0 { return system.ToVariable(1) } @@ -545,25 +542,23 @@ func (system *r1CS) Cmp(i1, i2 frontend.Variable) frontend.Variable { // --------------------------------------------------------------------------------------------- // Assertions -// IsConstant returns true if v is a constant known at compile time -func (system *r1CS) IsConstant(v frontend.Variable) bool { - if _v, ok := v.(compiled.Variable); ok { - return _v.IsConstant() - } - // it's not a wire, it's another golang type, we consider it constant. - // TODO we may want to use the struct parser to ensure this frontend.Variable interface doesn't contain fields which are - // frontend.Variable - return true -} - // ConstantValue returns the big.Int value of v. // Will panic if v.IsConstant() == false -func (system *r1CS) ConstantValue(v frontend.Variable) *big.Int { +func (system *r1CS) ConstantValue(v frontend.Variable) (*big.Int, bool) { if _v, ok := v.(compiled.Variable); ok { - return system.constantValue(_v) + _v.AssertIsSet() + + if len(_v.LinExp) != 1 { + return nil, false + } + cID, vID, visibility := _v.LinExp[0].Unpack() + if !(vID == 0 && visibility == schema.Public) { + return nil, false + } + return new(big.Int).Set(&system.st.Coeffs[cID]), true } r := utils.FromInterface(v) - return &r + return &r, true } func (system *r1CS) Backend() backend.ID { @@ -781,3 +776,7 @@ func (system *r1CS) negateLinExp(l []compiled.Term) []compiled.Term { } return res } + +func (system *r1CS) Compiler() frontend.Compiler { + return system +} diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go index c77e0e9016..6e1c6661f9 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/assertions.go @@ -47,8 +47,7 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { vars, _ := system.ToSymbols(i1) v := vars[0] - if v.IsConstant() { - c := system.constantValue(v) + if c, ok := system.ConstantValue(v); ok { if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", c.String())) } diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index c966e22d2b..f797b29d85 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -157,6 +157,9 @@ HINTLOOP: } func (cs *r1CS) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } cs.Schema = s } diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index 29ae8a1f65..da7ff0daa0 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -31,7 +31,7 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -func NewCompiler(curve ecc.ID) (frontend.Compiler, error) { +func NewCompiler(curve ecc.ID) (frontend.Builder, error) { return newR1CS(curve), nil } @@ -120,16 +120,6 @@ func (system *r1CS) AddSecretVariable(name string) frontend.Variable { return res } -// func (v *variable) constantValue(system *R1CS) *big.Int { -func (system *r1CS) constantValue(v compiled.Variable) *big.Int { - // TODO this might be a good place to start hunting useless allocations. - // maybe through a big.Int pool. - if !v.IsConstant() { - panic("can't get big.Int value on a non-constant variable") - } - return new(big.Int).Set(&system.st.Coeffs[v.LinExp[0].CoeffID()]) -} - func (system *r1CS) one() compiled.Variable { return compiled.Variable{ LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, @@ -201,7 +191,7 @@ func (system *r1CS) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. func (system *r1CS) MarkBoolean(v frontend.Variable) { - if system.IsConstant(v) { + if _, ok := system.ConstantValue(v); ok { return } // v is a linear expression @@ -220,8 +210,7 @@ func (system *r1CS) MarkBoolean(v frontend.Variable) { // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. func (system *r1CS) IsBoolean(v frontend.Variable) bool { - if system.IsConstant(v) { - b := system.ConstantValue(v) + if b, ok := system.ConstantValue(v); ok { return b.IsUint64() && b.Uint64() <= 1 } // v is a linear expression diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 592c45392a..d8ba0d6fd0 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -72,10 +72,10 @@ func (system *sparseR1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) // Neg returns -i func (system *sparseR1CS) Neg(i1 frontend.Variable) frontend.Variable { - if system.IsConstant(i1) { - k := system.ConstantValue(i1) - k.Neg(k) - return *k + if n, ok := system.ConstantValue(i1); ok { + n.Neg(n) + // TODO shouldn't that go through variable conversion? + return *n } else { v := i1.(compiled.Term) c, _, _ := v.Unpack() @@ -622,12 +622,14 @@ func (system *sparseR1CS) IsConstant(v frontend.Variable) bool { // ConstantValue returns the big.Int value of v. It // panics if v.IsConstant() == false -func (system *sparseR1CS) ConstantValue(v frontend.Variable) *big.Int { - if !system.IsConstant(v) { - panic("v should be a constant") +func (system *sparseR1CS) ConstantValue(v frontend.Variable) (*big.Int, bool) { + switch t := v.(type) { + case compiled.Term: + return nil, false + default: + res := utils.FromInterface(t) + return &res, true } - res := utils.FromInterface(v) - return &res } func (system *sparseR1CS) Backend() backend.ID { @@ -694,3 +696,7 @@ func (system *sparseR1CS) splitProd(acc compiled.Term, r []compiled.Term) compil system.addPlonkConstraint(acc, r[0], o, compiled.CoeffIdZero, compiled.CoeffIdZero, cl, cr, compiled.CoeffIdMinusOne, compiled.CoeffIdZero) return system.splitProd(o, r[1:]) } + +func (system *sparseR1CS) Compiler() frontend.Compiler { + return system +} diff --git a/frontend/cs/scs/conversion.go b/frontend/cs/scs/conversion.go index c68535963a..4a4dccc4e4 100644 --- a/frontend/cs/scs/conversion.go +++ b/frontend/cs/scs/conversion.go @@ -146,6 +146,9 @@ HINTLOOP: } func (cs *sparseR1CS) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } cs.Schema = s } diff --git a/frontend/cs/scs/sparse_r1cs.go b/frontend/cs/scs/sparse_r1cs.go index 24a2950cd3..06f78c0814 100644 --- a/frontend/cs/scs/sparse_r1cs.go +++ b/frontend/cs/scs/sparse_r1cs.go @@ -31,7 +31,7 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -func NewCompiler(curve ecc.ID) (frontend.Compiler, error) { +func NewCompiler(curve ecc.ID) (frontend.Builder, error) { return newSparseR1CS(curve), nil } @@ -169,8 +169,7 @@ func (system *sparseR1CS) zero() compiled.Term { // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. func (system *sparseR1CS) IsBoolean(v frontend.Variable) bool { - if system.IsConstant(v) { - b := system.ConstantValue(v) + if b, ok := system.ConstantValue(v); ok { return b.IsUint64() && b.Uint64() <= 1 } _, ok := system.mtBooleans[int(v.(compiled.Term))] diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index 3e64e8cf97..3b9e99294e 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -14,7 +14,7 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(mulBy7, 1, circuit.A) + res, err := api.Compiler().NewHint(mulBy7, 1, circuit.A) if err != nil { return fmt.Errorf("mulBy7: %w", err) } @@ -23,7 +23,7 @@ func (circuit *hintCircuit) Define(api frontend.API) error { api.AssertIsEqual(a7, _a7) api.AssertIsEqual(a7, circuit.B) - res, err = api.NewHint(make3, 1) + res, err = api.Compiler().NewHint(make3, 1) if err != nil { return fmt.Errorf("make3: %w", err) } @@ -39,7 +39,7 @@ type vectorDoubleCircuit struct { } func (c *vectorDoubleCircuit) Define(api frontend.API) error { - res, err := api.NewHint(dvHint, len(c.B), c.A...) + res, err := api.Compiler().NewHint(dvHint, len(c.B), c.A...) if err != nil { return fmt.Errorf("double newhint: %w", err) } diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index 208f299d13..150722b912 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -199,8 +199,8 @@ func (p *G1Affine) Double(api frontend.API, p1 G1Affine) *G1Affine { // then the compiled circuit depends on s. If it is variable type, then // the circuit is independent of the inputs. func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Affine { - if api.IsConstant(s) { - return P.constScalarMul(api, Q, api.ConstantValue(s)) + if n, ok := api.Compiler().ConstantValue(s); ok { + return P.constScalarMul(api, Q, n) } else { return P.varScalarMul(api, Q, s) } @@ -249,12 +249,12 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // points and the operations on the points are performed on the `inner` // curve of the outer curve. We require some parameters from the inner // curve. - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS12377, 3, s) + sd, err := api.Compiler().NewHint(scalarDecompositionHintBLS12377, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -343,7 +343,7 @@ func (P *G1Affine) constScalarMul(api frontend.API, Q G1Affine, s *big.Int) *G1A // bits are constant and here it makes sense to use the table in the main // loop. var Acc, negQ, negPhiQ, phiQ G1Affine - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) s.Mod(s, cc.fr) cc.phi(api, &phiQ, &Q) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index e18f11550b..7bcd1d9c24 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -199,8 +199,8 @@ func (p *G1Affine) Double(api frontend.API, p1 G1Affine) *G1Affine { // then the compiled circuit depends on s. If it is variable type, then // the circuit is independent of the inputs. func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Affine { - if api.IsConstant(s) { - return P.constScalarMul(api, Q, api.ConstantValue(s)) + if n, ok := api.Compiler().ConstantValue(s); ok { + return P.constScalarMul(api, Q, n) } else { return P.varScalarMul(api, Q, s) } @@ -249,12 +249,12 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // points and the operations on the points are performed on the `inner` // curve of the outer curve. We require some parameters from the inner // curve. - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS24315, 3, s) + sd, err := api.Compiler().NewHint(scalarDecompositionHintBLS24315, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -343,7 +343,7 @@ func (P *G1Affine) constScalarMul(api frontend.API, Q G1Affine, s *big.Int) *G1A // bits are constant and here it makes sense to use the table in the main // loop. var Acc, negQ, negPhiQ, phiQ G1Affine - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) s.Mod(s, cc.fr) cc.phi(api, &phiQ, &Q) diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index 94af797d4e..8b78d922ff 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -34,7 +34,7 @@ type mustBeOnCurve struct { func (circuit *mustBeOnCurve) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -69,7 +69,7 @@ type add struct { func (circuit *add) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -125,7 +125,7 @@ type addGeneric struct { func (circuit *addGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -180,7 +180,7 @@ type double struct { func (circuit *double) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -228,7 +228,7 @@ type scalarMulFixed struct { func (circuit *scalarMulFixed) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -279,7 +279,7 @@ type scalarMulGeneric struct { func (circuit *scalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index d9387e94fd..14dcb47b08 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -39,7 +39,7 @@ type mustBeOnCurve struct { func (circuit *mustBeOnCurve) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -74,7 +74,7 @@ type add struct { func (circuit *add) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -130,7 +130,7 @@ type addGeneric struct { func (circuit *addGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -271,7 +271,7 @@ type double struct { func (circuit *double) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -372,7 +372,7 @@ type scalarMulFixed struct { func (circuit *scalarMulFixed) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -473,7 +473,7 @@ type scalarMulGeneric struct { func (circuit *scalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -608,7 +608,7 @@ type doubleScalarMulGeneric struct { func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/std/fiat-shamir/transcript_test.go b/std/fiat-shamir/transcript_test.go index 672b202731..2911e84874 100644 --- a/std/fiat-shamir/transcript_test.go +++ b/std/fiat-shamir/transcript_test.go @@ -43,7 +43,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error { } // get the challenges - alpha, beta, gamma := getChallenges(api.Curve()) + alpha, beta, gamma := getChallenges(api.Compiler().Curve()) // New transcript with 3 challenges to be derived tsSnark := NewTranscript(api, &hSnark, alpha, beta, gamma) diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 90debe3d7e..36a9d07567 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -36,7 +36,7 @@ type MiMC struct { // NewMiMC returns a MiMC instance, than can be used in a gnark circuit func NewMiMC(api frontend.API) (MiMC, error) { - if constructor, ok := newMimc[api.Curve()]; ok { + if constructor, ok := newMimc[api.Compiler().Curve()]; ok { return constructor(api), nil } return MiMC{}, errors.New("unknown curve id") diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 84636c0a4b..9bc4265074 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -138,7 +138,7 @@ func parsePoint(id ecc.ID, buf []byte) ([]byte, []byte) { func (circuit *eddsaCircuit) Define(api frontend.API) error { - params, err := twistededwards.NewEdCurve(api.Curve()) + params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/test/engine.go b/test/engine.go index f7703758a9..4819834e50 100644 --- a/test/engine.go +++ b/test/engine.go @@ -368,9 +368,20 @@ func (e *engine) IsConstant(v frontend.Variable) bool { // ConstantValue returns the big.Int value of v // will panic if v.IsConstant() == false -func (e *engine) ConstantValue(v frontend.Variable) *big.Int { +func (e *engine) ConstantValue(v frontend.Variable) (*big.Int, bool) { r := e.toBigInt(v) - return &r + return &r, true +} + +func (e *engine) IsBoolean(v frontend.Variable) bool { + r := e.toBigInt(v) + return r.IsUint64() && r.Uint64() <= 1 +} + +func (e *engine) MarkBoolean(v frontend.Variable) { + if !e.IsBoolean(v) { + panic("mark boolean a non-boolean value") + } } func (e *engine) Tag(name string) frontend.Tag { @@ -462,3 +473,7 @@ func copyWitness(to, from frontend.Circuit) { _, _ = schema.Parse(to, tVariable, setHandler) } + +func (e *engine) Compiler() frontend.Compiler { + return e +} diff --git a/test/engine_test.go b/test/engine_test.go index 851e5ef972..59808beae1 100644 --- a/test/engine_test.go +++ b/test/engine_test.go @@ -15,22 +15,22 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(hint.IthBit, 1, circuit.A, 3) + res, err := api.Compiler().NewHint(hint.IthBit, 1, circuit.A, 3) if err != nil { return fmt.Errorf("IthBit circuitA 3: %w", err) } a3b := res[0] - res, err = api.NewHint(hint.IthBit, 1, circuit.A, 25) + res, err = api.Compiler().NewHint(hint.IthBit, 1, circuit.A, 25) if err != nil { return fmt.Errorf("IthBit circuitA 25: %w", err) } a25b := res[0] - res, err = api.NewHint(hint.IsZero, 1, circuit.A) + res, err = api.Compiler().NewHint(hint.IsZero, 1, circuit.A) if err != nil { return fmt.Errorf("IsZero CircuitA: %w", err) } aisZero := res[0] - res, err = api.NewHint(hint.IsZero, 1, circuit.B) + res, err = api.Compiler().NewHint(hint.IsZero, 1, circuit.B) if err != nil { return fmt.Errorf("IsZero, CircuitB") } From ba4958b8ad9d762ba615914525103709792ecb92 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 13:28:47 -0600 Subject: [PATCH 12/20] refactor: frontend/cs subpackages to match new interfaces --- frontend/cs/r1cs/api.go | 245 ++---- .../r1cs/{assertions.go => api_assertions.go} | 32 +- frontend/cs/r1cs/compiler.go | 721 ++++++++++++++++++ frontend/cs/r1cs/conversion.go | 260 ------- frontend/cs/r1cs/r1cs.go | 341 --------- frontend/cs/r1cs/r1cs_test.go | 2 +- frontend/cs/scs/api.go | 196 +---- .../scs/{assertions.go => api_assertions.go} | 12 +- frontend/cs/scs/compiler.go | 682 +++++++++++++++++ frontend/cs/scs/conversion.go | 253 ------ frontend/cs/scs/sparse_r1cs.go | 296 ------- 11 files changed, 1498 insertions(+), 1542 deletions(-) rename frontend/cs/r1cs/{assertions.go => api_assertions.go} (83%) create mode 100644 frontend/cs/r1cs/compiler.go delete mode 100644 frontend/cs/r1cs/conversion.go delete mode 100644 frontend/cs/r1cs/r1cs.go rename frontend/cs/scs/{assertions.go => api_assertions.go} (92%) create mode 100644 frontend/cs/scs/compiler.go delete mode 100644 frontend/cs/scs/conversion.go delete mode 100644 frontend/cs/scs/sparse_r1cs.go diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 3736998854..cc6287412c 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -25,22 +25,20 @@ import ( "strconv" "strings" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/utils" ) // --------------------------------------------------------------------------------------------- // Arithmetic // Add returns res = i1+i2+...in -func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) + vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable res := compiled.Variable{LinExp: make([]compiled.Term, 0, s)} @@ -56,12 +54,12 @@ func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front } // Neg returns -i -func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i) +func (system *compiler) Neg(i frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(i) if n, ok := system.ConstantValue(vars[0]); ok { n.Neg(n) - return system.ToVariable(n) + return system.toVariable(n) } res := compiled.Variable{LinExp: system.negateLinExp(vars[0].LinExp)} @@ -70,10 +68,10 @@ func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { } // Sub returns res = i1 - i2 -func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars, s := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) + vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable res := compiled.Variable{ @@ -94,8 +92,8 @@ func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) front } // Mul returns res = i1 * i2 * ... in -func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(append([]frontend.Variable{i1, i2}, in...)...) +func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) mul := func(v1, v2 compiled.Variable) compiled.Variable { @@ -112,7 +110,7 @@ func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { n1.Mul(n1, n2).Mod(n1, system.CurveID.Info().Fr.Modulus()) - return system.ToVariable(n1).(compiled.Variable) + return system.toVariable(n1).(compiled.Variable) } // ensure v2 is the constant @@ -132,7 +130,7 @@ func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variable { +func (system *compiler) mulConstant(v1, constant compiled.Variable) compiled.Variable { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable res := v1.Clone() @@ -159,8 +157,8 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl return res } -func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i1, i2) +func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(i1, i2) v1 := vars[0] v2 := vars[1] @@ -185,16 +183,16 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { if v1Constant { n2.Mul(n2, n1).Mod(n2, q) - return system.ToVariable(n2) + return system.toVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.ToVariable(n2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.Variable)) } // Div returns res = i1 / i2 -func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i1, i2) +func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(i1, i2) v1 := vars[0] v2 := vars[1] @@ -221,16 +219,16 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { if v1Constant { n2.Mul(n2, n1).Mod(n2, q) - return system.ToVariable(n2) + return system.toVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.ToVariable(n2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.Variable)) } // Inverse returns res = inverse(v) -func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i1) +func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(i1) if c, ok := system.ConstantValue(vars[0]); ok { if c.IsUint64() && c.Uint64() == 0 { @@ -238,7 +236,7 @@ func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { } c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) - return system.ToVariable(c) + return system.toVariable(c) } // allocate resulting frontend.Variable @@ -258,7 +256,7 @@ func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -269,14 +267,14 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable } } - vars, _ := system.ToSymbols(i1) + vars, _ := system.toVariables(i1) a := vars[0] // if a is a constant, work with the big int value. if c, ok := system.ConstantValue(a); ok { b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { - b[i] = system.ToVariable(c.Bit(i)) + b[i] = system.toVariable(c.Bit(i)) } return b } @@ -285,7 +283,7 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable } // toBinary is equivalent to ToBinary, exept the returned bits are NOT boolean constrained. -func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []frontend.Variable { +func (system *compiler) toBinary(a compiled.Variable, nbBits int, unsafe bool) []frontend.Variable { if _, ok := system.ConstantValue(a); ok { return system.ToBinary(a, nbBits) @@ -329,8 +327,8 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro } // FromBinary packs b, seen as a fr.Element in little endian -func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { - b, _ := system.ToSymbols(_b...) +func (system *compiler) FromBinary(_b ...frontend.Variable) frontend.Variable { + b, _ := system.toVariables(_b...) // ensure inputs are set for i := 0; i < len(b); i++ { @@ -340,7 +338,7 @@ func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { // res = Σ (2**i * b[i]) var res, v frontend.Variable - res = system.ToVariable(0) // no constraint is recorded + res = system.toVariable(0) // no constraint is recorded var c big.Int c.SetUint64(1) @@ -357,9 +355,9 @@ func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { } // Xor compute the XOR between two frontend.Variables -func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { +func (system *compiler) Xor(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(_a, _b) + vars, _ := system.toVariables(_a, _b) a := vars[0] b := vars[1] @@ -379,8 +377,8 @@ func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { } // Or compute the OR between two frontend.Variables -func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(_a, _b) +func (system *compiler) Or(_a, _b frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(_a, _b) a := vars[0] b := vars[1] @@ -399,8 +397,8 @@ func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { } // And compute the AND between two frontend.Variables -func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(_a, _b) +func (system *compiler) And(_a, _b frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(_a, _b) a := vars[0] b := vars[1] @@ -417,9 +415,9 @@ func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { // Conditionals // Select if i0 is true, yields i1 else yields i2 -func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { +func (system *compiler) Select(i0, i1, i2 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i0, i1, i2) + vars, _ := system.toVariables(i0, i1, i2) b := vars[0] // ensures that b is boolean @@ -451,8 +449,8 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(b0, b1, i0, i1, i2, i3) +func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -482,14 +480,14 @@ func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va } // IsZero returns 1 if i1 is zero, 0 otherwise -func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { - vars, _ := system.ToSymbols(i1) +func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { + vars, _ := system.toVariables(i1) a := vars[0] if c, ok := system.ConstantValue(a); ok { if c.IsUint64() && c.Uint64() == 0 { - return system.ToVariable(1) + return system.toVariable(1) } - return system.ToVariable(0) + return system.toVariable(0) } debug := system.AddDebugInfo("isZero", a) @@ -505,7 +503,7 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { panic(err) } m := res[0] - system.addConstraint(newR1C(a, m, system.ToVariable(0)), debug) + system.addConstraint(newR1C(a, m, system.toVariable(0)), debug) system.AssertIsBoolean(m) ma := system.Add(m, a) @@ -514,13 +512,13 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1= 0; i-- { @@ -539,38 +537,12 @@ func (system *r1CS) Cmp(i1, i2 frontend.Variable) frontend.Variable { return res } -// --------------------------------------------------------------------------------------------- -// Assertions - -// ConstantValue returns the big.Int value of v. -// Will panic if v.IsConstant() == false -func (system *r1CS) ConstantValue(v frontend.Variable) (*big.Int, bool) { - if _v, ok := v.(compiled.Variable); ok { - _v.AssertIsSet() - - if len(_v.LinExp) != 1 { - return nil, false - } - cID, vID, visibility := _v.LinExp[0].Unpack() - if !(vID == 0 && visibility == schema.Public) { - return nil, false - } - return new(big.Int).Set(&system.st.Coeffs[cID]), true - } - r := utils.FromInterface(v) - return &r, true -} - -func (system *r1CS) Backend() backend.ID { - return backend.GROTH16 -} - // Println enables circuit debugging and behaves almost like fmt.Println() // // the print will be done once the R1CS.Solve() method is executed // // if one of the input is a variable, its value will be resolved avec R1CS.Solve() method is called -func (system *r1CS) Println(a ...frontend.Variable) { +func (system *compiler) Println(a ...frontend.Variable) { var sbb strings.Builder // prefix log line with file.go:line @@ -647,125 +619,8 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) sbb.WriteByte('}') } -// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to -// measure constraints, variables and coefficients creations through AddCounter -func (system *r1CS) Tag(name string) frontend.Tag { - _, file, line, _ := runtime.Caller(1) - - return frontend.Tag{ - Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), - VID: system.NbInternalVariables, - CID: len(system.Constraints), - } -} - -// AddCounter measures the number of constraints, variables and coefficients created between two tags -func (system *r1CS) AddCounter(from, to frontend.Tag) { - system.Counters = append(system.Counters, compiled.Counter{ - From: from.Name, - To: to.Name, - NbVariables: to.VID - from.VID, - NbConstraints: to.CID - from.CID, - CurveID: system.CurveID, - BackendID: backend.GROTH16, - }) -} - -// NewHint initializes internal variables whose value will be evaluated using -// the provided hint function at run time from the inputs. Inputs must be either -// variables or convertible to *big.Int. The function returns an error if the -// number of inputs is not compatible with f. -// -// The hint function is provided at the proof creation time and is not embedded -// into the circuit. From the backend point of view, the variable returned by -// the hint function is equivalent to the user-supplied witness, but its actual -// value is assigned by the solver, not the caller. -// -// No new constraints are added to the newly created wire and must be added -// manually in the circuit. Failing to do so leads to solver failure. -func (system *r1CS) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - - if nbOutputs <= 0 { - return nil, fmt.Errorf("hint function must return at least one output") - } - hintInputs := make([]interface{}, len(inputs)) - - // ensure inputs are set and pack them in a []uint64 - for i, in := range inputs { - switch t := in.(type) { - case compiled.Variable: - tmp := t.Clone() - hintInputs[i] = tmp.LinExp - case compiled.LinearExpression: - tmp := make(compiled.LinearExpression, len(t)) - copy(tmp, t) - hintInputs[i] = tmp - default: - hintInputs[i] = utils.FromInterface(t) - } - } - - // prepare wires - varIDs := make([]int, nbOutputs) - res := make([]frontend.Variable, len(varIDs)) - for i := range varIDs { - r := system.newInternalVariable() - _, vID, _ := r.LinExp[0].Unpack() - varIDs[i] = vID - res[i] = r - } - - ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} - for _, vID := range varIDs { - system.MHints[vID] = ch - } - - return res, nil -} - -// ToVariable will return (and allocate if neccesary) a frontend.Variable from given value -// -// if input is already a frontend.Variable, does nothing -// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a ToVariable frontend.Variable -// -// a ToVariable frontend.Variable does NOT necessary allocate a frontend.Variable in the ConstraintSystem -// it is in the form ONE_WIRE * coeff -func (system *r1CS) ToVariable(input interface{}) frontend.Variable { - - switch t := input.(type) { - case compiled.Variable: - t.AssertIsSet() - return t - default: - n := utils.FromInterface(t) - if n.IsUint64() && n.Uint64() == 1 { - return system.one() - } - r := system.one() - r.LinExp[0] = system.setCoeff(r.LinExp[0], &n) - return r - } -} - -// ToSymbols return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (system *r1CS) ToSymbols(in ...frontend.Variable) ([]compiled.Variable, int) { - r := make([]compiled.Variable, 0, len(in)) - s := 0 - e := func(i frontend.Variable) { - v := system.ToVariable(i).(compiled.Variable) - r = append(r, v) - s += len(v.LinExp) - } - // e(i1) - // e(i2) - for i := 0; i < len(in); i++ { - e(in[i]) - } - return r, s -} - // returns -le, the result is a copy -func (system *r1CS) negateLinExp(l []compiled.Term) []compiled.Term { +func (system *compiler) negateLinExp(l []compiled.Term) []compiled.Term { res := make([]compiled.Term, len(l)) var lambda big.Int for i, t := range l { @@ -777,6 +632,6 @@ func (system *r1CS) negateLinExp(l []compiled.Term) []compiled.Term { return res } -func (system *r1CS) Compiler() frontend.Compiler { +func (system *compiler) Compiler() frontend.Compiler { return system } diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/api_assertions.go similarity index 83% rename from frontend/cs/r1cs/assertions.go rename to frontend/cs/r1cs/api_assertions.go index 6e1c6661f9..ea2f4f630c 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -26,10 +26,10 @@ import ( ) // AssertIsEqual adds an assertion in the constraint system (i1 == i2) -func (system *r1CS) AssertIsEqual(i1, i2 frontend.Variable) { +func (system *compiler) AssertIsEqual(i1, i2 frontend.Variable) { // encoded 1 * i1 == i2 - r := system.ToVariable(i1).(compiled.Variable) - o := system.ToVariable(i2).(compiled.Variable) + r := system.toVariable(i1).(compiled.Variable) + o := system.toVariable(i2).(compiled.Variable) debug := system.AddDebugInfo("assertIsEqual", r, " == ", o) @@ -37,14 +37,14 @@ func (system *r1CS) AssertIsEqual(i1, i2 frontend.Variable) { } // AssertIsDifferent constrain i1 and i2 to be different -func (system *r1CS) AssertIsDifferent(i1, i2 frontend.Variable) { +func (system *compiler) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } // AssertIsBoolean adds an assertion in the constraint system (v == 0 ∥ v == 1) -func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { +func (system *compiler) AssertIsBoolean(i1 frontend.Variable) { - vars, _ := system.ToSymbols(i1) + vars, _ := system.toVariables(i1) v := vars[0] if c, ok := system.ConstantValue(v); ok { @@ -61,7 +61,7 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { debug := system.AddDebugInfo("assertIsBoolean", v, " == (0|1)") - o := system.ToVariable(0) + o := system.toVariable(0) // ensure v * (1 - v) == 0 _v := system.Sub(1, v) @@ -74,8 +74,8 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { // // derived from: // https://github.com/zcash/zips/blob/main/protocol/protocol.pdf -func (system *r1CS) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { - v, _ := system.ToSymbols(_v) +func (system *compiler) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { + v, _ := system.toVariables(_v) switch b := bound.(type) { case compiled.Variable: @@ -87,7 +87,7 @@ func (system *r1CS) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Var } -func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { +func (system *compiler) mustBeLessOrEqVar(a, bound compiled.Variable) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := system.BitLen() @@ -96,9 +96,9 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { boundBits := system.ToBinary(bound, nbBits) p := make([]frontend.Variable, nbBits+1) - p[nbBits] = system.ToVariable(1) + p[nbBits] = system.toVariable(1) - zero := system.ToVariable(0) + zero := system.toVariable(0) for i := nbBits - 1; i >= 0; i-- { @@ -128,7 +128,7 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { } -func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { +func (system *compiler) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { nbBits := system.BitLen() @@ -141,7 +141,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { } // debug info - debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.ToVariable(bound)) + debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.toVariable(bound)) // note that at this stage, we didn't boolean-constraint these new variables yet // (as opposed to ToBinary) @@ -158,7 +158,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { p := make([]frontend.Variable, nbBits+1) // p[i] == 1 → a[j] == c[j] for all j ⩾ i - p[nbBits] = system.ToVariable(1) + p[nbBits] = system.toVariable(1) for i := nbBits - 1; i >= t; i-- { if bound.Bit(i) == 0 { @@ -174,7 +174,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { l := system.Sub(1, p[i+1]) l = system.Sub(l, aBits[i]) - system.addConstraint(newR1C(l, aBits[i], system.ToVariable(0)), debug) + system.addConstraint(newR1C(l, aBits[i], system.toVariable(0)), debug) system.MarkBoolean(aBits[i].(compiled.Variable)) } else { system.AssertIsBoolean(aBits[i]) diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go new file mode 100644 index 0000000000..4b6b14c5e6 --- /dev/null +++ b/frontend/cs/r1cs/compiler.go @@ -0,0 +1,721 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package r1cs + +import ( + "errors" + "fmt" + "math/big" + "path/filepath" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/frontend/schema" + bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" + bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" + bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" + bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" + bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" + bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" + "github.com/consensys/gnark/internal/utils" +) + +// NewCompiler returns a new R1CS compiler +func NewCompiler(curve ecc.ID) (frontend.Builder, error) { + return newCompiler(curve), nil +} + +type compiler struct { + compiled.ConstraintSystem + Constraints []compiled.R1C + + st cs.CoeffTable + + // map for recording boolean constrained variables (to not constrain them twice) + mtBooleans map[uint64][]compiled.LinearExpression +} + +// initialCapacity has quite some impact on frontend performance, especially on large circuits size +// we may want to add build tags to tune that +func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { + capacity := 0 + if len(initialCapacity) > 0 { + capacity = initialCapacity[0] + } + system := compiler{ + ConstraintSystem: compiled.ConstraintSystem{ + + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), + }, + Constraints: make([]compiled.R1C, 0, capacity), + st: cs.NewCoeffTable(), + mtBooleans: make(map[uint64][]compiled.LinearExpression), + } + + system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) + system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) + system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + + system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero + system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne + system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + + system.Public = make([]string, 1) + system.Secret = make([]string, 0) + + // by default the circuit is given a public wire equal to 1 + system.Public[0] = "one" + + system.CurveID = curveID + + return &system +} + +// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets +// the wire's id to the number of wires, and returns it +func (system *compiler) newInternalVariable() compiled.Variable { + idx := system.NbInternalVariables + system.NbInternalVariables++ + return compiled.Variable{ + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, + } +} + +// AddPublicVariable creates a new public Variable +func (system *compiler) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } + idx := len(system.Public) + system.Public = append(system.Public, name) + res := compiled.Variable{ + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, + } + return res +} + +// AddSecretVariable creates a new secret Variable +func (system *compiler) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } + idx := len(system.Secret) + system.Secret = append(system.Secret, name) + res := compiled.Variable{ + LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, + } + return res +} + +func (system *compiler) one() compiled.Variable { + return compiled.Variable{ + LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, + } +} + +// reduces redundancy in linear expression +// It factorizes Variable that appears multiple times with != coeff Ids +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset +// for each visibility, the Variables are sorted from lowest ID to highest ID +func (system *compiler) reduce(l compiled.Variable) compiled.Variable { + // ensure our linear expression is sorted, by visibility and by Variable ID + if !sort.IsSorted(l.LinExp) { // may not help + sort.Sort(l.LinExp) + } + + mod := system.CurveID.Info().Fr.Modulus() + c := new(big.Int) + for i := 1; i < len(l.LinExp); i++ { + pcID, pvID, pVis := l.LinExp[i-1].Unpack() + ccID, cvID, cVis := l.LinExp[i].Unpack() + if pVis == cVis && pvID == cvID { + // we have redundancy + c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) + c.Mod(c, mod) + l.LinExp[i-1].SetCoeffID(system.st.CoeffID(c)) + l.LinExp = append(l.LinExp[:i], l.LinExp[i+1:]...) + i-- + } + } + return l +} + +// newR1C clones the linear expression associated with the Variables (to avoid offseting the ID multiple time) +// and return a R1C +func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { + l := _l.(compiled.Variable) + r := _r.(compiled.Variable) + o := _o.(compiled.Variable) + + // interestingly, this is key to groth16 performance. + // l * r == r * l == o + // but the "l" linear expression is going to end up in the A matrix + // the "r" linear expression is going to end up in the B matrix + // the less Variable we have appearing in the B matrix, the more likely groth16.Setup + // is going to produce infinity points in pk.G1.B and pk.G2.B, which will speed up proving time + if len(l.LinExp) > len(r.LinExp) { + l, r = r, l + } + + return compiled.R1C{L: l.Clone(), R: r.Clone(), O: o.Clone()} +} + +func (system *compiler) addConstraint(r1c compiled.R1C, debugID ...int) { + system.Constraints = append(system.Constraints, r1c) + if len(debugID) > 0 { + system.MDebug[len(system.Constraints)-1] = debugID[0] + } +} + +// Term packs a Variable and a coeff in a Term and returns it. +// func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { +func (system *compiler) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { + _, vID, vVis := v.Unpack() + return compiled.Pack(vID, system.st.CoeffID(coeff), vVis) +} + +// MarkBoolean sets (but do not **constraint**!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *compiler) MarkBoolean(v frontend.Variable) { + if _, ok := system.ConstantValue(v); ok { + return + } + // v is a linear expression + l := v.(compiled.Variable).LinExp + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list := system.mtBooleans[key] + list = append(list, l) + system.mtBooleans[key] = list +} + +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *compiler) IsBoolean(v frontend.Variable) bool { + if b, ok := system.ConstantValue(v); ok { + return b.IsUint64() && b.Uint64() <= 1 + } + // v is a linear expression + l := v.(compiled.Variable).LinExp + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list, ok := system.mtBooleans[key] + if !ok { + return false + } + + for _, v := range list { + if v.Equal(l) { + return true + } + } + return false +} + +// checkVariables perform post compilation checks on the Variables +// +// 1. checks that all user inputs are referenced in at least one constraint +// 2. checks that all hints are constrained +func (system *compiler) checkVariables() error { + + // TODO @gbotrel add unit test for that. + + cptSecret := len(system.Secret) + cptPublic := len(system.Public) + cptHints := len(system.MHints) + + secretConstrained := make([]bool, cptSecret) + publicConstrained := make([]bool, cptPublic) + // one wire does not need to be constrained + publicConstrained[0] = true + cptPublic-- + + mHintsConstrained := make(map[int]bool) + + // for each constraint, we check the linear expressions and mark our inputs / hints as constrained + processLinearExpression := func(l compiled.Variable) { + for _, t := range l.LinExp { + if t.CoeffID() == compiled.CoeffIdZero { + // ignore zero coefficient, as it does not constraint the Variable + // though, we may want to flag that IF the Variable doesn't appear else where + continue + } + visibility := t.VariableVisibility() + vID := t.WireID() + + switch visibility { + case schema.Public: + if vID != 0 && !publicConstrained[vID] { + publicConstrained[vID] = true + cptPublic-- + } + case schema.Secret: + if !secretConstrained[vID] { + secretConstrained[vID] = true + cptSecret-- + } + case schema.Internal: + if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { + mHintsConstrained[vID] = true + cptHints-- + } + } + } + } + for _, r1c := range system.Constraints { + processLinearExpression(r1c.L) + processLinearExpression(r1c.R) + processLinearExpression(r1c.O) + + if cptHints|cptSecret|cptPublic == 0 { + return nil // we can stop. + } + + } + + // something is a miss, we build the error string + var sbb strings.Builder + if cptSecret != 0 { + sbb.WriteString(strconv.Itoa(cptSecret)) + sbb.WriteString(" unconstrained secret input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { + if !secretConstrained[i] { + sbb.WriteString(system.Secret[i]) + sbb.WriteByte('\n') + cptSecret-- + } + } + sbb.WriteByte('\n') + } + + if cptPublic != 0 { + sbb.WriteString(strconv.Itoa(cptPublic)) + sbb.WriteString(" unconstrained public input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { + if !publicConstrained[i] { + sbb.WriteString(system.Public[i]) + sbb.WriteByte('\n') + cptPublic-- + } + } + sbb.WriteByte('\n') + } + + if cptHints != 0 { + sbb.WriteString(strconv.Itoa(cptHints)) + sbb.WriteString(" unconstrained hints") + sbb.WriteByte('\n') + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some + // debugInfo to find where a hint was declared (and not constrained) + } + return errors.New(sbb.String()) +} + +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() +} + +// Compile constructs a rank-1 constraint sytem +func (cs *compiler) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !opt.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } + + // wires = public wires | secret wires | internal wires + + // setting up the result + res := compiled.R1CS{ + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, + } + res.NbPublicVariables = len(cs.Public) + res.NbSecretVariables = len(cs.Secret) + + // for Logs, DebugInfo and hints the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires + + // offset variable ID depeneding on visibility + shiftVID := func(oldID int, visibility schema.Visibility) int { + switch visibility { + case schema.Internal: + return oldID + res.NbPublicVariables + res.NbSecretVariables + case schema.Public: + return oldID + case schema.Secret: + return oldID + res.NbPublicVariables + } + return oldID + } + + // we just need to offset our ids, such that wires = [ public wires | secret wires | internal wires ] + offsetIDs := func(l compiled.LinearExpression) { + for j := 0; j < len(l); j++ { + _, vID, visibility := l[j].Unpack() + l[j].SetWireID(shiftVID(vID, visibility)) + } + } + + for i := 0; i < len(res.Constraints); i++ { + offsetIDs(res.Constraints[i].L.LinExp) + offsetIDs(res.Constraints[i].R.LinExp) + offsetIDs(res.Constraints[i].O.LinExp) + } + + // we need to offset the ids in the hints + shiftedMap := make(map[int]*compiled.Hint) + + // we need to offset the ids in the hints +HINTLOOP: + for _, hint := range cs.MHints { + ws := make([]int, len(hint.Wires)) + // we set for all outputs in shiftedMap. If one shifted output + // is in shiftedMap, then all are + for i, vID := range hint.Wires { + ws[i] = shiftVID(vID, schema.Internal) + if _, ok := shiftedMap[ws[i]]; i == 0 && ok { + continue HINTLOOP + } + } + inputs := make([]interface{}, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + switch t := inputs[j].(type) { + case compiled.Variable: + tmp := make(compiled.LinearExpression, len(t.LinExp)) + copy(tmp, t.LinExp) + offsetIDs(tmp) + inputs[j] = tmp + case compiled.LinearExpression: + tmp := make(compiled.LinearExpression, len(t)) + copy(tmp, t) + offsetIDs(tmp) + inputs[j] = tmp + default: + inputs[j] = t + } + } + ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} + for _, vID := range ws { + shiftedMap[vID] = ch + } + } + res.MHints = shiftedMap + + // we need to offset the ids in Logs & DebugInfo + for i := 0; i < len(cs.Logs); i++ { + + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() + res.Logs[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) + } + } + for i := 0; i < len(cs.DebugInfo); i++ { + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() + res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) + } + } + + // build levels + res.Levels = buildLevels(res) + + switch cs.CurveID { + case ecc.BLS12_377: + return bls12377r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BLS12_381: + return bls12381r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BN254: + return bn254r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_761: + return bw6761r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_633: + return bw6633r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BLS24_315: + return bls24315r1cs.NewR1CS(res, cs.st.Coeffs), nil + default: + panic("not implemtented") + } +} + +func (cs *compiler) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } + cs.Schema = s +} + +func buildLevels(ccs compiled.R1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processLE(c.L.LinExp, cID) + b.processLE(c.R.LinExp, cID) + b.processLE(c.O.LinExp, cID) + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.R1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { + + for _, t := range l { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + continue + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + continue + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + b.processLE(t.LinExp, cID) + case compiled.LinearExpression: + b.processLE(t, cID) + case compiled.Term: + b.processLE(compiled.LinearExpression{t}, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + continue + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + } +} + +// ConstantValue returns the big.Int value of v. +// Will panic if v.IsConstant() == false +func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { + if _v, ok := v.(compiled.Variable); ok { + _v.AssertIsSet() + + if len(_v.LinExp) != 1 { + return nil, false + } + cID, vID, visibility := _v.LinExp[0].Unpack() + if !(vID == 0 && visibility == schema.Public) { + return nil, false + } + return new(big.Int).Set(&system.st.Coeffs[cID]), true + } + r := utils.FromInterface(v) + return &r, true +} + +func (system *compiler) Backend() backend.ID { + return backend.GROTH16 +} + +// toVariable will return (and allocate if neccesary) a compiled.Variable from given value +// +// if input is already a compiled.Variable, does nothing +// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable compiled.Variable +func (system *compiler) toVariable(input interface{}) frontend.Variable { + + switch t := input.(type) { + case compiled.Variable: + t.AssertIsSet() + return t + default: + n := utils.FromInterface(t) + if n.IsUint64() && n.Uint64() == 1 { + return system.one() + } + r := system.one() + r.LinExp[0] = system.setCoeff(r.LinExp[0], &n) + return r + } +} + +// toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions +func (system *compiler) toVariables(in ...frontend.Variable) ([]compiled.Variable, int) { + r := make([]compiled.Variable, 0, len(in)) + s := 0 + e := func(i frontend.Variable) { + v := system.toVariable(i).(compiled.Variable) + r = append(r, v) + s += len(v.LinExp) + } + // e(i1) + // e(i2) + for i := 0; i < len(in); i++ { + e(in[i]) + } + return r, s +} + +// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to +// measure constraints, variables and coefficients creations through AddCounter +func (system *compiler) Tag(name string) frontend.Tag { + _, file, line, _ := runtime.Caller(1) + + return frontend.Tag{ + Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), + VID: system.NbInternalVariables, + CID: len(system.Constraints), + } +} + +// AddCounter measures the number of constraints, variables and coefficients created between two tags +func (system *compiler) AddCounter(from, to frontend.Tag) { + system.Counters = append(system.Counters, compiled.Counter{ + From: from.Name, + To: to.Name, + NbVariables: to.VID - from.VID, + NbConstraints: to.CID - from.CID, + CurveID: system.CurveID, + BackendID: backend.GROTH16, + }) +} + +// NewHint initializes internal variables whose value will be evaluated using +// the provided hint function at run time from the inputs. Inputs must be either +// variables or convertible to *big.Int. The function returns an error if the +// number of inputs is not compatible with f. +// +// The hint function is provided at the proof creation time and is not embedded +// into the circuit. From the backend point of view, the variable returned by +// the hint function is equivalent to the user-supplied witness, but its actual +// value is assigned by the solver, not the caller. +// +// No new constraints are added to the newly created wire and must be added +// manually in the circuit. Failing to do so leads to solver failure. +func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + + if nbOutputs <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + hintInputs := make([]interface{}, len(inputs)) + + // ensure inputs are set and pack them in a []uint64 + for i, in := range inputs { + switch t := in.(type) { + case compiled.Variable: + tmp := t.Clone() + hintInputs[i] = tmp.LinExp + case compiled.LinearExpression: + tmp := make(compiled.LinearExpression, len(t)) + copy(tmp, t) + hintInputs[i] = tmp + default: + hintInputs[i] = utils.FromInterface(t) + } + } + + // prepare wires + varIDs := make([]int, nbOutputs) + res := make([]frontend.Variable, len(varIDs)) + for i := range varIDs { + r := system.newInternalVariable() + _, vID, _ := r.LinExp[0].Unpack() + varIDs[i] = vID + res[i] = r + } + + ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + for _, vID := range varIDs { + system.MHints[vID] = ch + } + + return res, nil +} diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go deleted file mode 100644 index f797b29d85..0000000000 --- a/frontend/cs/r1cs/conversion.go +++ /dev/null @@ -1,260 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package r1cs - -import ( - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/compiled" - "github.com/consensys/gnark/frontend/schema" - bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" - bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" - bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" - bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" - bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" - bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" -) - -// Compile constructs a rank-1 constraint sytem -func (cs *r1CS) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { - - // ensure all inputs and hints are constrained - if !opt.IgnoreUnconstrainedInputs { - if err := cs.checkVariables(); err != nil { - return nil, err - } - } - - // wires = public wires | secret wires | internal wires - - // setting up the result - res := compiled.R1CS{ - ConstraintSystem: cs.ConstraintSystem, - Constraints: cs.Constraints, - } - res.NbPublicVariables = len(cs.Public) - res.NbSecretVariables = len(cs.Secret) - - // for Logs, DebugInfo and hints the only thing that will change - // is that ID of the wires will be offseted to take into account the final wire vector ordering - // that is: public wires | secret wires | internal wires - - // offset variable ID depeneding on visibility - shiftVID := func(oldID int, visibility schema.Visibility) int { - switch visibility { - case schema.Internal: - return oldID + res.NbPublicVariables + res.NbSecretVariables - case schema.Public: - return oldID - case schema.Secret: - return oldID + res.NbPublicVariables - } - return oldID - } - - // we just need to offset our ids, such that wires = [ public wires | secret wires | internal wires ] - offsetIDs := func(l compiled.LinearExpression) { - for j := 0; j < len(l); j++ { - _, vID, visibility := l[j].Unpack() - l[j].SetWireID(shiftVID(vID, visibility)) - } - } - - for i := 0; i < len(res.Constraints); i++ { - offsetIDs(res.Constraints[i].L.LinExp) - offsetIDs(res.Constraints[i].R.LinExp) - offsetIDs(res.Constraints[i].O.LinExp) - } - - // we need to offset the ids in the hints - shiftedMap := make(map[int]*compiled.Hint) - - // we need to offset the ids in the hints -HINTLOOP: - for _, hint := range cs.MHints { - ws := make([]int, len(hint.Wires)) - // we set for all outputs in shiftedMap. If one shifted output - // is in shiftedMap, then all are - for i, vID := range hint.Wires { - ws[i] = shiftVID(vID, schema.Internal) - if _, ok := shiftedMap[ws[i]]; i == 0 && ok { - continue HINTLOOP - } - } - inputs := make([]interface{}, len(hint.Inputs)) - copy(inputs, hint.Inputs) - for j := 0; j < len(inputs); j++ { - switch t := inputs[j].(type) { - case compiled.Variable: - tmp := make(compiled.LinearExpression, len(t.LinExp)) - copy(tmp, t.LinExp) - offsetIDs(tmp) - inputs[j] = tmp - case compiled.LinearExpression: - tmp := make(compiled.LinearExpression, len(t)) - copy(tmp, t) - offsetIDs(tmp) - inputs[j] = tmp - default: - inputs[j] = t - } - } - ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} - for _, vID := range ws { - shiftedMap[vID] = ch - } - } - res.MHints = shiftedMap - - // we need to offset the ids in Logs & DebugInfo - for i := 0; i < len(cs.Logs); i++ { - - for j := 0; j < len(res.Logs[i].ToResolve); j++ { - _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() - res.Logs[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) - } - } - for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { - _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() - res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) - } - } - - // build levels - res.Levels = buildLevels(res) - - switch cs.CurveID { - case ecc.BLS12_377: - return bls12377r1cs.NewR1CS(res, cs.st.Coeffs), nil - case ecc.BLS12_381: - return bls12381r1cs.NewR1CS(res, cs.st.Coeffs), nil - case ecc.BN254: - return bn254r1cs.NewR1CS(res, cs.st.Coeffs), nil - case ecc.BW6_761: - return bw6761r1cs.NewR1CS(res, cs.st.Coeffs), nil - case ecc.BW6_633: - return bw6633r1cs.NewR1CS(res, cs.st.Coeffs), nil - case ecc.BLS24_315: - return bls24315r1cs.NewR1CS(res, cs.st.Coeffs), nil - default: - panic("not implemtented") - } -} - -func (cs *r1CS) SetSchema(s *schema.Schema) { - if cs.Schema != nil { - panic("SetSchema called multiple times") - } - cs.Schema = s -} - -func buildLevels(ccs compiled.R1CS) [][]int { - - b := levelBuilder{ - mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire - nodeLevels: make([]int, len(ccs.Constraints)), // level of a node - mLevels: make(map[int]int), // level counts - ccs: ccs, - nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, - } - - // for each constraint, we're going to find its direct dependencies - // that is, wires (solved by previous constraints) on which it depends - // each of these dependencies is tagged with a level - // current constraint will be tagged with max(level) + 1 - for cID, c := range ccs.Constraints { - - b.nodeLevel = 0 - - b.processLE(c.L.LinExp, cID) - b.processLE(c.R.LinExp, cID) - b.processLE(c.O.LinExp, cID) - b.nodeLevels[cID] = b.nodeLevel - b.mLevels[b.nodeLevel]++ - - } - - levels := make([][]int, len(b.mLevels)) - for i := 0; i < len(levels); i++ { - // allocate memory - levels[i] = make([]int, 0, b.mLevels[i]) - } - - for n, l := range b.nodeLevels { - levels[l] = append(levels[l], n) - } - - return levels -} - -type levelBuilder struct { - ccs compiled.R1CS - nbInputs int - - mWireToNode map[int]int // at which node we resolved which wire - nodeLevels []int // level per node - mLevels map[int]int // number of constraint per level - - nodeLevel int // current level -} - -func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { - - for _, t := range l { - wID := t.WireID() - if wID < b.nbInputs { - // it's a input, we ignore it - continue - } - - // if we know a which constraint solves this wire, then it's a dependency - n, ok := b.mWireToNode[wID] - if ok { - if n != cID { // can happen with hints... - // we add a dependency, check if we need to increment our current level - if b.nodeLevels[n] >= b.nodeLevel { - b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it - } - } - continue - } - - // check if it's a hint and mark all the output wires - if h, ok := b.ccs.MHints[wID]; ok { - - for _, in := range h.Inputs { - switch t := in.(type) { - case compiled.Variable: - b.processLE(t.LinExp, cID) - case compiled.LinearExpression: - b.processLE(t, cID) - case compiled.Term: - b.processLE(compiled.LinearExpression{t}, cID) - } - } - - for _, hwid := range h.Wires { - b.mWireToNode[hwid] = cID - } - continue - } - - // mark this wire solved by current node - b.mWireToNode[wID] = cID - } -} diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go deleted file mode 100644 index da7ff0daa0..0000000000 --- a/frontend/cs/r1cs/r1cs.go +++ /dev/null @@ -1,341 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package r1cs - -import ( - "errors" - "math/big" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/compiled" - "github.com/consensys/gnark/frontend/cs" - "github.com/consensys/gnark/frontend/schema" -) - -func NewCompiler(curve ecc.ID) (frontend.Builder, error) { - return newR1CS(curve), nil -} - -type r1CS struct { - compiled.ConstraintSystem - Constraints []compiled.R1C - - st cs.CoeffTable - - // map for recording boolean constrained variables (to not constrain them twice) - mtBooleans map[uint64][]compiled.LinearExpression -} - -// initialCapacity has quite some impact on frontend performance, especially on large circuits size -// we may want to add build tags to tune that -func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } - system := r1CS{ - ConstraintSystem: compiled.ConstraintSystem{ - - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, - Constraints: make([]compiled.R1C, 0, capacity), - st: cs.NewCoeffTable(), - mtBooleans: make(map[uint64][]compiled.LinearExpression), - } - - system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - - system.Public = make([]string, 1) - system.Secret = make([]string, 0) - - // by default the circuit is given a public wire equal to 1 - system.Public[0] = "one" - - system.CurveID = curveID - - return &system -} - -// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets -// the wire's id to the number of wires, and returns it -func (system *r1CS) newInternalVariable() compiled.Variable { - idx := system.NbInternalVariables - system.NbInternalVariables++ - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, - } -} - -// AddPublicVariable creates a new public Variable -func (system *r1CS) AddPublicVariable(name string) frontend.Variable { - if system.Schema != nil { - panic("do not call AddPublicVariable in circuit.Define()") - } - idx := len(system.Public) - system.Public = append(system.Public, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, - } - return res -} - -// AddSecretVariable creates a new secret Variable -func (system *r1CS) AddSecretVariable(name string) frontend.Variable { - if system.Schema != nil { - panic("do not call AddSecretVariable in circuit.Define()") - } - idx := len(system.Secret) - system.Secret = append(system.Secret, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, - } - return res -} - -func (system *r1CS) one() compiled.Variable { - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, - } -} - -// reduces redundancy in linear expression -// It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset -// for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *r1CS) reduce(l compiled.Variable) compiled.Variable { - // ensure our linear expression is sorted, by visibility and by Variable ID - if !sort.IsSorted(l.LinExp) { // may not help - sort.Sort(l.LinExp) - } - - mod := system.CurveID.Info().Fr.Modulus() - c := new(big.Int) - for i := 1; i < len(l.LinExp); i++ { - pcID, pvID, pVis := l.LinExp[i-1].Unpack() - ccID, cvID, cVis := l.LinExp[i].Unpack() - if pVis == cVis && pvID == cvID { - // we have redundancy - c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) - c.Mod(c, mod) - l.LinExp[i-1].SetCoeffID(system.st.CoeffID(c)) - l.LinExp = append(l.LinExp[:i], l.LinExp[i+1:]...) - i-- - } - } - return l -} - -// newR1C clones the linear expression associated with the Variables (to avoid offseting the ID multiple time) -// and return a R1C -func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { - l := _l.(compiled.Variable) - r := _r.(compiled.Variable) - o := _o.(compiled.Variable) - - // interestingly, this is key to groth16 performance. - // l * r == r * l == o - // but the "l" linear expression is going to end up in the A matrix - // the "r" linear expression is going to end up in the B matrix - // the less Variable we have appearing in the B matrix, the more likely groth16.Setup - // is going to produce infinity points in pk.G1.B and pk.G2.B, which will speed up proving time - if len(l.LinExp) > len(r.LinExp) { - l, r = r, l - } - - return compiled.R1C{L: l.Clone(), R: r.Clone(), O: o.Clone()} -} - -func (system *r1CS) addConstraint(r1c compiled.R1C, debugID ...int) { - system.Constraints = append(system.Constraints, r1c) - if len(debugID) > 0 { - system.MDebug[len(system.Constraints)-1] = debugID[0] - } -} - -// Term packs a Variable and a coeff in a Term and returns it. -// func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { -func (system *r1CS) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { - _, vID, vVis := v.Unpack() - return compiled.Pack(vID, system.st.CoeffID(coeff), vVis) -} - -// MarkBoolean sets (but do not **constraint**!) v to be boolean -// This is useful in scenarios where a variable is known to be boolean through a constraint -// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. -func (system *r1CS) MarkBoolean(v frontend.Variable) { - if _, ok := system.ConstantValue(v); ok { - return - } - // v is a linear expression - l := v.(compiled.Variable).LinExp - if !sort.IsSorted(l) { - sort.Sort(l) - } - - key := l.HashCode() - list := system.mtBooleans[key] - list = append(list, l) - system.mtBooleans[key] = list -} - -// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) -// Use with care; variable may not have been **constrained** to be boolean -// This returns true if the v is a constant and v == 0 || v == 1. -func (system *r1CS) IsBoolean(v frontend.Variable) bool { - if b, ok := system.ConstantValue(v); ok { - return b.IsUint64() && b.Uint64() <= 1 - } - // v is a linear expression - l := v.(compiled.Variable).LinExp - if !sort.IsSorted(l) { - sort.Sort(l) - } - - key := l.HashCode() - list, ok := system.mtBooleans[key] - if !ok { - return false - } - - for _, v := range list { - if v.Equal(l) { - return true - } - } - return false -} - -// checkVariables perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (system *r1CS) checkVariables() error { - - // TODO @gbotrel add unit test for that. - - cptSecret := len(system.Secret) - cptPublic := len(system.Public) - cptHints := len(system.MHints) - - secretConstrained := make([]bool, cptSecret) - publicConstrained := make([]bool, cptPublic) - // one wire does not need to be constrained - publicConstrained[0] = true - cptPublic-- - - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the linear expressions and mark our inputs / hints as constrained - processLinearExpression := func(l compiled.Variable) { - for _, t := range l.LinExp { - if t.CoeffID() == compiled.CoeffIdZero { - // ignore zero coefficient, as it does not constraint the Variable - // though, we may want to flag that IF the Variable doesn't appear else where - continue - } - visibility := t.VariableVisibility() - vID := t.WireID() - - switch visibility { - case schema.Public: - if vID != 0 && !publicConstrained[vID] { - publicConstrained[vID] = true - cptPublic-- - } - case schema.Secret: - if !secretConstrained[vID] { - secretConstrained[vID] = true - cptSecret-- - } - case schema.Internal: - if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - } - for _, r1c := range system.Constraints { - processLinearExpression(r1c.L) - processLinearExpression(r1c.R) - processLinearExpression(r1c.O) - - if cptHints|cptSecret|cptPublic == 0 { - return nil // we can stop. - } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptSecret != 0 { - sbb.WriteString(strconv.Itoa(cptSecret)) - sbb.WriteString(" unconstrained secret input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { - if !secretConstrained[i] { - sbb.WriteString(system.Secret[i]) - sbb.WriteByte('\n') - cptSecret-- - } - } - sbb.WriteByte('\n') - } - - if cptPublic != 0 { - sbb.WriteString(strconv.Itoa(cptPublic)) - sbb.WriteString(" unconstrained public input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { - if !publicConstrained[i] { - sbb.WriteString(system.Public[i]) - sbb.WriteByte('\n') - cptPublic-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index d16bb98c3a..557f2fbb96 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -50,7 +50,7 @@ func TestQuickSort(t *testing.T) { func TestReduce(t *testing.T) { - cs := newR1CS(ecc.BN254) + cs := newCompiler(ecc.BN254) x := cs.newInternalVariable() y := cs.newInternalVariable() z := cs.newInternalVariable() diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index d8ba0d6fd0..48816556f7 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -25,7 +25,6 @@ import ( "strconv" "strings" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/compiled" @@ -34,7 +33,7 @@ import ( ) // Add returns res = i1+i2+...in -func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { zero := big.NewInt(0) vars, k := system.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) @@ -54,7 +53,7 @@ func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) } // neg returns -in -func (system *sparseR1CS) neg(in []frontend.Variable) []frontend.Variable { +func (system *compiler) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) @@ -65,13 +64,13 @@ func (system *sparseR1CS) neg(in []frontend.Variable) []frontend.Variable { } // Sub returns res = i1 - i2 - ...in -func (system *sparseR1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := system.neg(append([]frontend.Variable{i2}, in...)) return system.Add(i1, r[0], r[1:]...) } // Neg returns -i -func (system *sparseR1CS) Neg(i1 frontend.Variable) frontend.Variable { +func (system *compiler) Neg(i1 frontend.Variable) frontend.Variable { if n, ok := system.ConstantValue(i1); ok { n.Neg(n) // TODO shouldn't that go through variable conversion? @@ -89,7 +88,7 @@ func (system *sparseR1CS) Neg(i1 frontend.Variable) frontend.Variable { } // Mul returns res = i1 * i2 * ... in -func (system *sparseR1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := system.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { @@ -101,7 +100,7 @@ func (system *sparseR1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) } // returns t*m -func (system *sparseR1CS) mulConstant(t compiled.Term, m *big.Int) compiled.Term { +func (system *compiler) mulConstant(t compiled.Term, m *big.Int) compiled.Term { var coef big.Int cid, _, _ := t.Unpack() coef.Set(&system.st.Coeffs[cid]) @@ -112,7 +111,7 @@ func (system *sparseR1CS) mulConstant(t compiled.Term, m *big.Int) compiled.Term } // DivUnchecked returns i1 / i2 . if i1 == i2 == 0, returns 0 -func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { +func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { if system.IsConstant(i1) && system.IsConstant(i2) { l := utils.FromInterface(i1) @@ -144,7 +143,7 @@ func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variab } // Div returns i1 / i2 -func (system *sparseR1CS) Div(i1, i2 frontend.Variable) frontend.Variable { +func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint system.Inverse(i2) @@ -153,7 +152,7 @@ func (system *sparseR1CS) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = 1 / i1 -func (system *sparseR1CS) Inverse(i1 frontend.Variable) frontend.Variable { +func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { if system.IsConstant(i1) { c := utils.FromInterface(i1) c.ModInverse(&c, system.CurveID.Info().Fr.Modulus()) @@ -175,7 +174,7 @@ func (system *sparseR1CS) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -200,7 +199,7 @@ func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Va return system.toBinary(a, nbBits, false) } -func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { +func (system *compiler) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { // allocate the resulting frontend.Variables and bit-constraint them sb := make([]frontend.Variable, nbBits) @@ -238,7 +237,7 @@ func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []f } // FromBinary packs b, seen as a fr.Element in little endian -func (system *sparseR1CS) FromBinary(b ...frontend.Variable) frontend.Variable { +func (system *compiler) FromBinary(b ...frontend.Variable) frontend.Variable { _b := make([]frontend.Variable, len(b)) var c big.Int c.SetUint64(1) @@ -257,7 +256,7 @@ func (system *sparseR1CS) FromBinary(b ...frontend.Variable) frontend.Variable { // Xor returns a ^ b // a and b must be 0 or 1 -func (system *sparseR1CS) Xor(a, b frontend.Variable) frontend.Variable { +func (system *compiler) Xor(a, b frontend.Variable) frontend.Variable { if system.IsConstant(a) && system.IsConstant(b) { _a := utils.FromInterface(a) _b := utils.FromInterface(b) @@ -286,7 +285,7 @@ func (system *sparseR1CS) Xor(a, b frontend.Variable) frontend.Variable { // Or returns a | b // a and b must be 0 or 1 -func (system *sparseR1CS) Or(a, b frontend.Variable) frontend.Variable { +func (system *compiler) Or(a, b frontend.Variable) frontend.Variable { var zero, one big.Int one.SetUint64(1) @@ -328,7 +327,7 @@ func (system *sparseR1CS) Or(a, b frontend.Variable) frontend.Variable { // Or returns a & b // a and b must be 0 or 1 -func (system *sparseR1CS) And(a, b frontend.Variable) frontend.Variable { +func (system *compiler) And(a, b frontend.Variable) frontend.Variable { system.AssertIsBoolean(a) system.AssertIsBoolean(b) return system.Mul(a, b) @@ -338,7 +337,7 @@ func (system *sparseR1CS) And(a, b frontend.Variable) frontend.Variable { // Conditionals // Select if b is true, yields i1 else yields i2 -func (system *sparseR1CS) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { +func (system *compiler) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { if system.IsConstant(b) { _b := utils.FromInterface(b) @@ -362,7 +361,7 @@ func (system *sparseR1CS) Select(b frontend.Variable, i1, i2 frontend.Variable) // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *sparseR1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { +func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) // s0, s1 := vars[0], vars[1] @@ -398,7 +397,7 @@ func (system *sparseR1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 front } // IsZero returns 1 if a is zero, 0 otherwise -func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { +func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { if system.IsConstant(i1) { a := utils.FromInterface(i1) @@ -427,7 +426,7 @@ func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 bound -func (system *sparseR1CS) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { +func (system *compiler) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { case compiled.Term: system.mustBeLessOrEqVar(v.(compiled.Term), b) @@ -96,7 +96,7 @@ func (system *sparseR1CS) AssertIsLessOrEqual(v frontend.Variable, bound fronten } } -func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { +func (system *compiler) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -143,7 +143,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term } -func (system *sparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { +func (system *compiler) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { nbBits := system.BitLen() diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go new file mode 100644 index 0000000000..391d852e3d --- /dev/null +++ b/frontend/cs/scs/compiler.go @@ -0,0 +1,682 @@ +/* +Copyright © 2021 ConsenSys Software Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scs + +import ( + "errors" + "fmt" + "math/big" + "path/filepath" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/frontend/schema" + bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" + bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" + bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" + bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" + bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" + bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" + "github.com/consensys/gnark/internal/utils" +) + +func NewCompiler(curve ecc.ID) (frontend.Builder, error) { + return newCompiler(curve), nil +} + +type compiler struct { + compiled.ConstraintSystem + Constraints []compiled.SparseR1C + + st cs.CoeffTable + + // map for recording boolean constrained variables (to not constrain them twice) + mtBooleans map[int]struct{} +} + +// initialCapacity has quite some impact on frontend performance, especially on large circuits size +// we may want to add build tags to tune that +func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { + capacity := 0 + if len(initialCapacity) > 0 { + capacity = initialCapacity[0] + } + system := compiler{ + ConstraintSystem: compiled.ConstraintSystem{ + + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), + }, + mtBooleans: make(map[int]struct{}), + Constraints: make([]compiled.SparseR1C, 0, capacity), + st: cs.NewCoeffTable(), + } + + system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) + system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) + system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + + system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero + system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne + system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + + // system.public.variables = make([]Variable, 0) + // system.secret.variables = make([]Variable, 0) + // system.internal = make([]Variable, 0, capacity) + system.Public = make([]string, 0) + system.Secret = make([]string, 0) + + system.CurveID = curveID + + return &system +} + +// addPlonkConstraint creates a constraint of the for al+br+clr+k=0 +//func (system *SparseR1CS) addPlonkConstraint(l, r, o frontend.Variable, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { +func (system *compiler) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { + + if len(debugID) > 0 { + system.MDebug[len(system.Constraints)] = debugID[0] + } + + l.SetCoeffID(cidl) + r.SetCoeffID(cidr) + o.SetCoeffID(cido) + + u := l + v := r + u.SetCoeffID(cidm1) + v.SetCoeffID(cidm2) + + //system.Constraints = append(system.Constraints, compiled.SparseR1C{L: _l, R: _r, O: _o, M: [2]compiled.Term{u, v}, K: k}) + system.Constraints = append(system.Constraints, compiled.SparseR1C{L: l, R: r, O: o, M: [2]compiled.Term{u, v}, K: k}) +} + +// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets +// the wire's id to the number of wires, and returns it +func (system *compiler) newInternalVariable() compiled.Term { + idx := system.NbInternalVariables + system.NbInternalVariables++ + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) +} + +// AddPublicVariable creates a new Public Variable +func (system *compiler) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } + idx := len(system.Public) + system.Public = append(system.Public, name) + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public) +} + +// AddSecretVariable creates a new Secret Variable +func (system *compiler) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } + idx := len(system.Secret) + system.Secret = append(system.Secret, name) + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret) +} + +// reduces redundancy in linear expression +// It factorizes Variable that appears multiple times with != coeff Ids +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset +// for each visibility, the Variables are sorted from lowest ID to highest ID +func (system *compiler) reduce(l compiled.LinearExpression) compiled.LinearExpression { + + // ensure our linear expression is sorted, by visibility and by Variable ID + sort.Sort(l) + + mod := system.CurveID.Info().Fr.Modulus() + c := new(big.Int) + for i := 1; i < len(l); i++ { + pcID, pvID, pVis := l[i-1].Unpack() + ccID, cvID, cVis := l[i].Unpack() + if pVis == cVis && pvID == cvID { + // we have redundancy + c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) + c.Mod(c, mod) + l[i-1].SetCoeffID(system.st.CoeffID(c)) + l = append(l[:i], l[i+1:]...) + i-- + } + } + return l +} + +// to handle wires that don't exist (=coef 0) in a sparse constraint +func (system *compiler) zero() compiled.Term { + var a compiled.Term + return a +} + +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *compiler) IsBoolean(v frontend.Variable) bool { + if b, ok := system.ConstantValue(v); ok { + return b.IsUint64() && b.Uint64() <= 1 + } + _, ok := system.mtBooleans[int(v.(compiled.Term))] + return ok +} + +// MarkBoolean sets (but do not constraint!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *compiler) MarkBoolean(v frontend.Variable) { + if system.IsConstant(v) { + return + } + system.mtBooleans[int(v.(compiled.Term))] = struct{}{} +} + +// checkVariables perform post compilation checks on the Variables +// +// 1. checks that all user inputs are referenced in at least one constraint +// 2. checks that all hints are constrained +func (system *compiler) checkVariables() error { + + // TODO @gbotrel add unit test for that. + + cptSecret := len(system.Secret) + cptPublic := len(system.Public) + cptHints := len(system.MHints) + + // compared to R1CS, we may have a circuit which does not have any inputs + // (R1CS always has a constant ONE wire). Check the edge case and omit any + // processing if so. + if cptSecret+cptPublic+cptHints == 0 { + return nil + } + + secretConstrained := make([]bool, cptSecret) + publicConstrained := make([]bool, cptPublic) + + mHintsConstrained := make(map[int]bool) + + // for each constraint, we check the terms and mark our inputs / hints as constrained + processTerm := func(t compiled.Term) { + + // L and M[0] handles the same wire but with a different coeff + visibility := t.VariableVisibility() + vID := t.WireID() + if t.CoeffID() != compiled.CoeffIdZero { + switch visibility { + case schema.Public: + if !publicConstrained[vID] { + publicConstrained[vID] = true + cptPublic-- + } + case schema.Secret: + if !secretConstrained[vID] { + secretConstrained[vID] = true + cptSecret-- + } + case schema.Internal: + if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { + mHintsConstrained[vID] = true + cptHints-- + } + } + } + + } + for _, c := range system.Constraints { + processTerm(c.L) + processTerm(c.R) + processTerm(c.M[0]) + processTerm(c.M[1]) + processTerm(c.O) + if cptHints|cptSecret|cptPublic == 0 { + return nil // we can stop. + } + + } + + // something is a miss, we build the error string + var sbb strings.Builder + if cptSecret != 0 { + sbb.WriteString(strconv.Itoa(cptSecret)) + sbb.WriteString(" unconstrained secret input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { + if !secretConstrained[i] { + sbb.WriteString(system.Secret[i]) + sbb.WriteByte('\n') + cptSecret-- + } + } + sbb.WriteByte('\n') + } + + if cptPublic != 0 { + sbb.WriteString(strconv.Itoa(cptPublic)) + sbb.WriteString(" unconstrained public input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { + if !publicConstrained[i] { + sbb.WriteString(system.Public[i]) + sbb.WriteByte('\n') + cptPublic-- + } + } + sbb.WriteByte('\n') + } + + if cptHints != 0 { + sbb.WriteString(strconv.Itoa(cptHints)) + sbb.WriteString(" unconstrained hints") + sbb.WriteByte('\n') + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some + // debugInfo to find where a hint was declared (and not constrained) + } + return errors.New(sbb.String()) +} + +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() +} + +func (cs *compiler) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !opt.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } + + res := compiled.SparseR1CS{ + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, + } + res.NbPublicVariables = len(cs.Public) + res.NbSecretVariables = len(cs.Secret) + + // Logs, DebugInfo and hints are copied, the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires + + // shift variable ID + // we want publicWires | privateWires | internalWires + shiftVID := func(oldID int, visibility schema.Visibility) int { + switch visibility { + case schema.Internal: + return oldID + res.NbPublicVariables + res.NbSecretVariables + case schema.Public: + return oldID + case schema.Secret: + return oldID + res.NbPublicVariables + default: + return oldID + } + } + + offsetTermID := func(t *compiled.Term) { + _, VID, visibility := t.Unpack() + t.SetWireID(shiftVID(VID, visibility)) + } + + // offset the IDs of all constraints so that the variables are + // numbered like this: [publicVariables | secretVariables | internalVariables ] + for i := 0; i < len(res.Constraints); i++ { + r1c := &res.Constraints[i] + offsetTermID(&r1c.L) + offsetTermID(&r1c.R) + offsetTermID(&r1c.O) + offsetTermID(&r1c.M[0]) + offsetTermID(&r1c.M[1]) + } + + // we need to offset the ids in Logs & DebugInfo + for i := 0; i < len(cs.Logs); i++ { + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + offsetTermID(&res.Logs[i].ToResolve[j]) + } + } + for i := 0; i < len(cs.DebugInfo); i++ { + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + offsetTermID(&res.DebugInfo[i].ToResolve[j]) + } + } + + // we need to offset the ids in the hints + shiftedMap := make(map[int]*compiled.Hint) +HINTLOOP: + for _, hint := range cs.MHints { + ws := make([]int, len(hint.Wires)) + // we set for all outputs in shiftedMap. If one shifted output + // is in shiftedMap, then all are + for i, vID := range hint.Wires { + ws[i] = shiftVID(vID, schema.Internal) + if _, ok := shiftedMap[ws[i]]; i == 0 && ok { + continue HINTLOOP + } + } + inputs := make([]interface{}, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + switch t := inputs[j].(type) { + case compiled.Term: + offsetTermID(&t) + inputs[j] = t // TODO check if we can remove it + default: + inputs[j] = t + } + } + ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} + for _, vID := range ws { + shiftedMap[vID] = ch + } + } + res.MHints = shiftedMap + + // build levels + res.Levels = buildLevels(res) + + switch cs.CurveID { + case ecc.BLS12_377: + return bls12377r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BLS12_381: + return bls12381r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BN254: + return bn254r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_761: + return bw6761r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BLS24_315: + return bls24315r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_633: + return bw6633r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + default: + panic("unknown curveID") + } + +} + +func (cs *compiler) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } + cs.Schema = s +} + +func buildLevels(ccs compiled.SparseR1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processTerm(c.L, cID) + b.processTerm(c.R, cID) + b.processTerm(c.O, cID) + + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.SparseR1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processTerm(t compiled.Term, cID int) { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + return + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + return + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + for _, tt := range t.LinExp { + b.processTerm(tt, cID) + } + case compiled.LinearExpression: + for _, tt := range t { + b.processTerm(tt, cID) + } + case compiled.Term: + b.processTerm(t, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + + return + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + +} + +// ConstantValue returns the big.Int value of v. It +// panics if v.IsConstant() == false +func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { + switch t := v.(type) { + case compiled.Term: + return nil, false + default: + res := utils.FromInterface(t) + return &res, true + } +} + +func (system *compiler) Backend() backend.ID { + return backend.PLONK +} + +// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to +// measure constraints, variables and coefficients creations through AddCounter +func (system *compiler) Tag(name string) frontend.Tag { + _, file, line, _ := runtime.Caller(1) + + return frontend.Tag{ + Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), + VID: system.NbInternalVariables, + CID: len(system.Constraints), + } +} + +// AddCounter measures the number of constraints, variables and coefficients created between two tags +// note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions +// are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time +func (system *compiler) AddCounter(from, to frontend.Tag) { + system.Counters = append(system.Counters, compiled.Counter{ + From: from.Name, + To: to.Name, + NbVariables: to.VID - from.VID, + NbConstraints: to.CID - from.CID, + CurveID: system.CurveID, + BackendID: backend.PLONK, + }) +} + +// NewHint initializes internal variables whose value will be evaluated using +// the provided hint function at run time from the inputs. Inputs must be either +// variables or convertible to *big.Int. The function returns an error if the +// number of inputs is not compatible with f. +// +// The hint function is provided at the proof creation time and is not embedded +// into the circuit. From the backend point of view, the variable returned by +// the hint function is equivalent to the user-supplied witness, but its actual +// value is assigned by the solver, not the caller. +// +// No new constraints are added to the newly created wire and must be added +// manually in the circuit. Failing to do so leads to solver failure. +func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + + if nbOutputs <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + + hintInputs := make([]interface{}, len(inputs)) + + // ensure inputs are set and pack them in a []uint64 + for i, in := range inputs { + switch t := in.(type) { + case compiled.Term: + hintInputs[i] = t + default: + hintInputs[i] = utils.FromInterface(in) + } + } + + // prepare wires + varIDs := make([]int, nbOutputs) + res := make([]frontend.Variable, len(varIDs)) + for i := range varIDs { + r := system.newInternalVariable() + _, vID, _ := r.Unpack() + varIDs[i] = vID + res[i] = r + } + + ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + for _, vID := range varIDs { + system.MHints[vID] = ch + } + + return res, nil +} + +// returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt +func (system *compiler) filterConstantSum(in []frontend.Variable) ([]compiled.Term, big.Int) { + res := make([]compiled.Term, 0, len(in)) + var b big.Int + for i := 0; i < len(in); i++ { + switch t := in[i].(type) { + case compiled.Term: + res = append(res, t) + default: + n := utils.FromInterface(t) + b.Add(&b, &n) + } + } + return res, b +} + +// returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt +func (system *compiler) filterConstantProd(in []frontend.Variable) ([]compiled.Term, big.Int) { + res := make([]compiled.Term, 0, len(in)) + var b big.Int + b.SetInt64(1) + for i := 0; i < len(in); i++ { + switch t := in[i].(type) { + case compiled.Term: + res = append(res, t) + default: + n := utils.FromInterface(t) + b.Mul(&b, &n).Mod(&b, system.CurveID.Info().Fr.Modulus()) + } + } + return res, b +} + +func (system *compiler) splitSum(acc compiled.Term, r []compiled.Term) compiled.Term { + + // floor case + if len(r) == 0 { + return acc + } + + cl, _, _ := acc.Unpack() + cr, _, _ := r[0].Unpack() + o := system.newInternalVariable() + system.addPlonkConstraint(acc, r[0], o, cl, cr, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdMinusOne, compiled.CoeffIdZero) + return system.splitSum(o, r[1:]) +} + +func (system *compiler) splitProd(acc compiled.Term, r []compiled.Term) compiled.Term { + + // floor case + if len(r) == 0 { + return acc + } + + cl, _, _ := acc.Unpack() + cr, _, _ := r[0].Unpack() + o := system.newInternalVariable() + system.addPlonkConstraint(acc, r[0], o, compiled.CoeffIdZero, compiled.CoeffIdZero, cl, cr, compiled.CoeffIdMinusOne, compiled.CoeffIdZero) + return system.splitProd(o, r[1:]) +} diff --git a/frontend/cs/scs/conversion.go b/frontend/cs/scs/conversion.go deleted file mode 100644 index 4a4dccc4e4..0000000000 --- a/frontend/cs/scs/conversion.go +++ /dev/null @@ -1,253 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scs - -import ( - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/compiled" - "github.com/consensys/gnark/frontend/schema" - bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" - bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" - bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" - bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" - bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" - bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" -) - -func (cs *sparseR1CS) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { - - // ensure all inputs and hints are constrained - if !opt.IgnoreUnconstrainedInputs { - if err := cs.checkVariables(); err != nil { - return nil, err - } - } - - res := compiled.SparseR1CS{ - ConstraintSystem: cs.ConstraintSystem, - Constraints: cs.Constraints, - } - res.NbPublicVariables = len(cs.Public) - res.NbSecretVariables = len(cs.Secret) - - // Logs, DebugInfo and hints are copied, the only thing that will change - // is that ID of the wires will be offseted to take into account the final wire vector ordering - // that is: public wires | secret wires | internal wires - - // shift variable ID - // we want publicWires | privateWires | internalWires - shiftVID := func(oldID int, visibility schema.Visibility) int { - switch visibility { - case schema.Internal: - return oldID + res.NbPublicVariables + res.NbSecretVariables - case schema.Public: - return oldID - case schema.Secret: - return oldID + res.NbPublicVariables - default: - return oldID - } - } - - offsetTermID := func(t *compiled.Term) { - _, VID, visibility := t.Unpack() - t.SetWireID(shiftVID(VID, visibility)) - } - - // offset the IDs of all constraints so that the variables are - // numbered like this: [publicVariables | secretVariables | internalVariables ] - for i := 0; i < len(res.Constraints); i++ { - r1c := &res.Constraints[i] - offsetTermID(&r1c.L) - offsetTermID(&r1c.R) - offsetTermID(&r1c.O) - offsetTermID(&r1c.M[0]) - offsetTermID(&r1c.M[1]) - } - - // we need to offset the ids in Logs & DebugInfo - for i := 0; i < len(cs.Logs); i++ { - for j := 0; j < len(res.Logs[i].ToResolve); j++ { - offsetTermID(&res.Logs[i].ToResolve[j]) - } - } - for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { - offsetTermID(&res.DebugInfo[i].ToResolve[j]) - } - } - - // we need to offset the ids in the hints - shiftedMap := make(map[int]*compiled.Hint) -HINTLOOP: - for _, hint := range cs.MHints { - ws := make([]int, len(hint.Wires)) - // we set for all outputs in shiftedMap. If one shifted output - // is in shiftedMap, then all are - for i, vID := range hint.Wires { - ws[i] = shiftVID(vID, schema.Internal) - if _, ok := shiftedMap[ws[i]]; i == 0 && ok { - continue HINTLOOP - } - } - inputs := make([]interface{}, len(hint.Inputs)) - copy(inputs, hint.Inputs) - for j := 0; j < len(inputs); j++ { - switch t := inputs[j].(type) { - case compiled.Term: - offsetTermID(&t) - inputs[j] = t // TODO check if we can remove it - default: - inputs[j] = t - } - } - ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} - for _, vID := range ws { - shiftedMap[vID] = ch - } - } - res.MHints = shiftedMap - - // build levels - res.Levels = buildLevels(res) - - switch cs.CurveID { - case ecc.BLS12_377: - return bls12377r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - case ecc.BLS12_381: - return bls12381r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - case ecc.BN254: - return bn254r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - case ecc.BW6_761: - return bw6761r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - case ecc.BLS24_315: - return bls24315r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - case ecc.BW6_633: - return bw6633r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil - default: - panic("unknown curveID") - } - -} - -func (cs *sparseR1CS) SetSchema(s *schema.Schema) { - if cs.Schema != nil { - panic("SetSchema called multiple times") - } - cs.Schema = s -} - -func buildLevels(ccs compiled.SparseR1CS) [][]int { - - b := levelBuilder{ - mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire - nodeLevels: make([]int, len(ccs.Constraints)), // level of a node - mLevels: make(map[int]int), // level counts - ccs: ccs, - nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, - } - - // for each constraint, we're going to find its direct dependencies - // that is, wires (solved by previous constraints) on which it depends - // each of these dependencies is tagged with a level - // current constraint will be tagged with max(level) + 1 - for cID, c := range ccs.Constraints { - - b.nodeLevel = 0 - - b.processTerm(c.L, cID) - b.processTerm(c.R, cID) - b.processTerm(c.O, cID) - - b.nodeLevels[cID] = b.nodeLevel - b.mLevels[b.nodeLevel]++ - - } - - levels := make([][]int, len(b.mLevels)) - for i := 0; i < len(levels); i++ { - // allocate memory - levels[i] = make([]int, 0, b.mLevels[i]) - } - - for n, l := range b.nodeLevels { - levels[l] = append(levels[l], n) - } - - return levels -} - -type levelBuilder struct { - ccs compiled.SparseR1CS - nbInputs int - - mWireToNode map[int]int // at which node we resolved which wire - nodeLevels []int // level per node - mLevels map[int]int // number of constraint per level - - nodeLevel int // current level -} - -func (b *levelBuilder) processTerm(t compiled.Term, cID int) { - wID := t.WireID() - if wID < b.nbInputs { - // it's a input, we ignore it - return - } - - // if we know a which constraint solves this wire, then it's a dependency - n, ok := b.mWireToNode[wID] - if ok { - if n != cID { // can happen with hints... - // we add a dependency, check if we need to increment our current level - if b.nodeLevels[n] >= b.nodeLevel { - b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it - } - } - return - } - - // check if it's a hint and mark all the output wires - if h, ok := b.ccs.MHints[wID]; ok { - - for _, in := range h.Inputs { - switch t := in.(type) { - case compiled.Variable: - for _, tt := range t.LinExp { - b.processTerm(tt, cID) - } - case compiled.LinearExpression: - for _, tt := range t { - b.processTerm(tt, cID) - } - case compiled.Term: - b.processTerm(t, cID) - } - } - - for _, hwid := range h.Wires { - b.mWireToNode[hwid] = cID - } - - return - } - - // mark this wire solved by current node - b.mWireToNode[wID] = cID - -} diff --git a/frontend/cs/scs/sparse_r1cs.go b/frontend/cs/scs/sparse_r1cs.go deleted file mode 100644 index 06f78c0814..0000000000 --- a/frontend/cs/scs/sparse_r1cs.go +++ /dev/null @@ -1,296 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scs - -import ( - "errors" - "math/big" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/compiled" - "github.com/consensys/gnark/frontend/cs" - "github.com/consensys/gnark/frontend/schema" -) - -func NewCompiler(curve ecc.ID) (frontend.Builder, error) { - return newSparseR1CS(curve), nil -} - -type sparseR1CS struct { - compiled.ConstraintSystem - Constraints []compiled.SparseR1C - - st cs.CoeffTable - - // map for recording boolean constrained variables (to not constrain them twice) - mtBooleans map[int]struct{} -} - -// initialCapacity has quite some impact on frontend performance, especially on large circuits size -// we may want to add build tags to tune that -func newSparseR1CS(curveID ecc.ID, initialCapacity ...int) *sparseR1CS { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } - system := sparseR1CS{ - ConstraintSystem: compiled.ConstraintSystem{ - - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, - mtBooleans: make(map[int]struct{}), - Constraints: make([]compiled.SparseR1C, 0, capacity), - st: cs.NewCoeffTable(), - } - - system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - - // system.public.variables = make([]Variable, 0) - // system.secret.variables = make([]Variable, 0) - // system.internal = make([]Variable, 0, capacity) - system.Public = make([]string, 0) - system.Secret = make([]string, 0) - - system.CurveID = curveID - - return &system -} - -// addPlonkConstraint creates a constraint of the for al+br+clr+k=0 -//func (system *SparseR1CS) addPlonkConstraint(l, r, o frontend.Variable, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { -func (system *sparseR1CS) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { - - if len(debugID) > 0 { - system.MDebug[len(system.Constraints)] = debugID[0] - } - - l.SetCoeffID(cidl) - r.SetCoeffID(cidr) - o.SetCoeffID(cido) - - u := l - v := r - u.SetCoeffID(cidm1) - v.SetCoeffID(cidm2) - - //system.Constraints = append(system.Constraints, compiled.SparseR1C{L: _l, R: _r, O: _o, M: [2]compiled.Term{u, v}, K: k}) - system.Constraints = append(system.Constraints, compiled.SparseR1C{L: l, R: r, O: o, M: [2]compiled.Term{u, v}, K: k}) -} - -// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets -// the wire's id to the number of wires, and returns it -func (system *sparseR1CS) newInternalVariable() compiled.Term { - idx := system.NbInternalVariables - system.NbInternalVariables++ - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) -} - -// AddPublicVariable creates a new Public Variable -func (system *sparseR1CS) AddPublicVariable(name string) frontend.Variable { - if system.Schema != nil { - panic("do not call AddPublicVariable in circuit.Define()") - } - idx := len(system.Public) - system.Public = append(system.Public, name) - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public) -} - -// AddSecretVariable creates a new Secret Variable -func (system *sparseR1CS) AddSecretVariable(name string) frontend.Variable { - if system.Schema != nil { - panic("do not call AddSecretVariable in circuit.Define()") - } - idx := len(system.Secret) - system.Secret = append(system.Secret, name) - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret) -} - -// reduces redundancy in linear expression -// It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset -// for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExpression { - - // ensure our linear expression is sorted, by visibility and by Variable ID - sort.Sort(l) - - mod := system.CurveID.Info().Fr.Modulus() - c := new(big.Int) - for i := 1; i < len(l); i++ { - pcID, pvID, pVis := l[i-1].Unpack() - ccID, cvID, cVis := l[i].Unpack() - if pVis == cVis && pvID == cvID { - // we have redundancy - c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) - c.Mod(c, mod) - l[i-1].SetCoeffID(system.st.CoeffID(c)) - l = append(l[:i], l[i+1:]...) - i-- - } - } - return l -} - -// to handle wires that don't exist (=coef 0) in a sparse constraint -func (system *sparseR1CS) zero() compiled.Term { - var a compiled.Term - return a -} - -// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) -// Use with care; variable may not have been **constrained** to be boolean -// This returns true if the v is a constant and v == 0 || v == 1. -func (system *sparseR1CS) IsBoolean(v frontend.Variable) bool { - if b, ok := system.ConstantValue(v); ok { - return b.IsUint64() && b.Uint64() <= 1 - } - _, ok := system.mtBooleans[int(v.(compiled.Term))] - return ok -} - -// MarkBoolean sets (but do not constraint!) v to be boolean -// This is useful in scenarios where a variable is known to be boolean through a constraint -// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. -func (system *sparseR1CS) MarkBoolean(v frontend.Variable) { - if system.IsConstant(v) { - return - } - system.mtBooleans[int(v.(compiled.Term))] = struct{}{} -} - -// checkVariables perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (system *sparseR1CS) checkVariables() error { - - // TODO @gbotrel add unit test for that. - - cptSecret := len(system.Secret) - cptPublic := len(system.Public) - cptHints := len(system.MHints) - - // compared to R1CS, we may have a circuit which does not have any inputs - // (R1CS always has a constant ONE wire). Check the edge case and omit any - // processing if so. - if cptSecret+cptPublic+cptHints == 0 { - return nil - } - - secretConstrained := make([]bool, cptSecret) - publicConstrained := make([]bool, cptPublic) - - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the terms and mark our inputs / hints as constrained - processTerm := func(t compiled.Term) { - - // L and M[0] handles the same wire but with a different coeff - visibility := t.VariableVisibility() - vID := t.WireID() - if t.CoeffID() != compiled.CoeffIdZero { - switch visibility { - case schema.Public: - if !publicConstrained[vID] { - publicConstrained[vID] = true - cptPublic-- - } - case schema.Secret: - if !secretConstrained[vID] { - secretConstrained[vID] = true - cptSecret-- - } - case schema.Internal: - if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - - } - for _, c := range system.Constraints { - processTerm(c.L) - processTerm(c.R) - processTerm(c.M[0]) - processTerm(c.M[1]) - processTerm(c.O) - if cptHints|cptSecret|cptPublic == 0 { - return nil // we can stop. - } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptSecret != 0 { - sbb.WriteString(strconv.Itoa(cptSecret)) - sbb.WriteString(" unconstrained secret input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { - if !secretConstrained[i] { - sbb.WriteString(system.Secret[i]) - sbb.WriteByte('\n') - cptSecret-- - } - } - sbb.WriteByte('\n') - } - - if cptPublic != 0 { - sbb.WriteString(strconv.Itoa(cptPublic)) - sbb.WriteString(" unconstrained public input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { - if !publicConstrained[i] { - sbb.WriteString(system.Public[i]) - sbb.WriteByte('\n') - cptPublic-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} From 6034277a342cec83e97b3ec95110cdd3206ab85e Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 13:48:54 -0600 Subject: [PATCH 13/20] perf(plonk): IsConstant -> ConstantValue --- frontend/cs/r1cs/compiler.go | 5 +- frontend/cs/scs/api.go | 111 +++++++++++++----------------- frontend/cs/scs/api_assertions.go | 24 ++++--- frontend/cs/scs/compiler.go | 6 +- 4 files changed, 67 insertions(+), 79 deletions(-) diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index 4b6b14c5e6..466b34aedd 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -204,7 +204,10 @@ func (system *compiler) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. func (system *compiler) MarkBoolean(v frontend.Variable) { - if _, ok := system.ConstantValue(v); ok { + if b, ok := system.ConstantValue(v); ok { + if !(b.IsUint64() && b.Uint64() <= 1) { + panic("MarkBoolean called a non-boolean constant") + } return } // v is a linear expression diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 48816556f7..67796cb23e 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -29,7 +29,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/utils" ) // Add returns res = i1+i2+...in @@ -112,25 +111,26 @@ func (system *compiler) mulConstant(t compiled.Term, m *big.Int) compiled.Term { // DivUnchecked returns i1 / i2 . if i1 == i2 == 0, returns 0 func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + c1, i1Constant := system.ConstantValue(i1) + c2, i2Constant := system.ConstantValue(i2) - if system.IsConstant(i1) && system.IsConstant(i2) { - l := utils.FromInterface(i1) - r := utils.FromInterface(i2) + if i1Constant && i2Constant { + l := c1 + r := c2 q := system.CurveID.Info().Fr.Modulus() - return r.ModInverse(&r, q). - Mul(&l, &r). - Mod(&r, q) + return r.ModInverse(r, q). + Mul(l, r). + Mod(r, q) } - if system.IsConstant(i2) { - c := utils.FromInterface(i2) + if i2Constant { + c := c2 m := system.CurveID.Info().Fr.Modulus() - c.ModInverse(&c, m) - return system.mulConstant(i1.(compiled.Term), &c) + c.ModInverse(c, m) + return system.mulConstant(i1.(compiled.Term), c) } - if system.IsConstant(i1) { + if i1Constant { res := system.Inverse(i2) - m := utils.FromInterface(i1) - return system.mulConstant(res.(compiled.Term), &m) + return system.mulConstant(res.(compiled.Term), c1) } res := system.newInternalVariable() @@ -153,9 +153,8 @@ func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { // Inverse returns res = 1 / i1 func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { - if system.IsConstant(i1) { - c := utils.FromInterface(i1) - c.ModInverse(&c, system.CurveID.Info().Fr.Modulus()) + if c, ok := system.ConstantValue(i1); ok { + c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) return c } t := i1.(compiled.Term) @@ -186,8 +185,7 @@ func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari } // if a is a constant, work with the big int value. - if system.IsConstant(i1) { - c := utils.FromInterface(i1) + if c, ok := system.ConstantValue(i1); ok { b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { b[i] = c.Bit(i) @@ -257,23 +255,25 @@ func (system *compiler) FromBinary(b ...frontend.Variable) frontend.Variable { // Xor returns a ^ b // a and b must be 0 or 1 func (system *compiler) Xor(a, b frontend.Variable) frontend.Variable { - if system.IsConstant(a) && system.IsConstant(b) { - _a := utils.FromInterface(a) - _b := utils.FromInterface(b) - _a.Xor(&_a, &_b) + _a, aConstant := system.ConstantValue(a) + _b, bConstant := system.ConstantValue(b) + + if aConstant && bConstant { + _a.Xor(_a, _b) return _a } res := system.newInternalVariable() - if system.IsConstant(a) { + if aConstant { a, b = b, a + bConstant = aConstant + _b = _a } - if system.IsConstant(b) { + if bConstant { l := a.(compiled.Term) r := l - _b := utils.FromInterface(b) one := big.NewInt(1) - _b.Lsh(&_b, 1).Sub(&_b, one) - idl := system.st.CoeffID(&_b) + _b.Lsh(_b, 1).Sub(_b, one) + idl := system.st.CoeffID(_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } @@ -286,34 +286,31 @@ func (system *compiler) Xor(a, b frontend.Variable) frontend.Variable { // Or returns a | b // a and b must be 0 or 1 func (system *compiler) Or(a, b frontend.Variable) frontend.Variable { + _a, aConstant := system.ConstantValue(a) + _b, bConstant := system.ConstantValue(b) - var zero, one big.Int - one.SetUint64(1) - - if system.IsConstant(a) && system.IsConstant(b) { - _a := utils.FromInterface(a) - _b := utils.FromInterface(b) - _a.Or(&_a, &_b) + if aConstant && bConstant { + _a.Or(_a, _b) return _a } res := system.newInternalVariable() - if system.IsConstant(a) { + if aConstant { a, b = b, a + _b = _a + bConstant = aConstant } - if system.IsConstant(b) { - _b := utils.FromInterface(b) - + if bConstant { l := a.(compiled.Term) r := l - if _b.Cmp(&one) != 0 && _b.Cmp(&zero) != 0 { + if !(_b.IsUint64() && (_b.Uint64() <= 1)) { panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) } system.AssertIsBoolean(a) one := big.NewInt(1) - _b.Sub(&_b, one) - idl := system.st.CoeffID(&_b) + _b.Sub(_b, one) + idl := system.st.CoeffID(_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } @@ -338,15 +335,13 @@ func (system *compiler) And(a, b frontend.Variable) frontend.Variable { // Select if b is true, yields i1 else yields i2 func (system *compiler) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { + _b, bConstant := system.ConstantValue(b) - if system.IsConstant(b) { - _b := utils.FromInterface(b) - var t big.Int - one := big.NewInt(1) - if _b.Cmp(&t) != 0 && _b.Cmp(one) != 0 { - panic("b should be a boolean") + if bConstant { + if !(_b.IsUint64() && (_b.Uint64() <= 1)) { + panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) } - if _b.Cmp(&t) == 0 { + if _b.Uint64() == 0 { return i2 } return i1 @@ -398,11 +393,8 @@ func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if a is zero, 0 otherwise func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { - - if system.IsConstant(i1) { - a := utils.FromInterface(i1) - var zero big.Int - if a.Cmp(&zero) != 0 { + if a, ok := system.ConstantValue(i1); ok { + if !(a.IsUint64() && a.Uint64() == 0) { panic("input should be zero") } return 1 @@ -534,17 +526,6 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) sbb.WriteByte('}') } -// IsConstant returns true if v is a constant known at compile time -func (system *compiler) IsConstant(v frontend.Variable) bool { - switch t := v.(type) { - case compiled.Term: - return false - default: - utils.FromInterface(t) - return true - } -} - func (system *compiler) Compiler() frontend.Compiler { return system } diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index bb795511c2..841e3da54b 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -28,24 +28,27 @@ import ( // AssertIsEqual fails if i1 != i2 func (system *compiler) AssertIsEqual(i1, i2 frontend.Variable) { - if system.IsConstant(i1) && system.IsConstant(i2) { - a := utils.FromInterface(i1) - b := utils.FromInterface(i2) - if a.Cmp(&b) != 0 { + c1, i1Constant := system.ConstantValue(i1) + c2, i2Constant := system.ConstantValue(i2) + + if i1Constant && i2Constant { + if c1.Cmp(c2) != 0 { panic("i1, i2 should be equal") } return } - if system.IsConstant(i1) { + if i1Constant { i1, i2 = i2, i1 + i2Constant = i1Constant + c2 = c1 } - if system.IsConstant(i2) { + if i2Constant { l := i1.(compiled.Term) lc, _, _ := l.Unpack() - k := utils.FromInterface(i2) + k := c2 debug := system.AddDebugInfo("assertIsEqual", l, "+", i2, " == 0") - k.Neg(&k) - _k := system.st.CoeffID(&k) + k.Neg(k) + _k := system.st.CoeffID(k) system.addPlonkConstraint(l, system.zero(), system.zero(), lc, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, _k, debug) return } @@ -65,8 +68,7 @@ func (system *compiler) AssertIsDifferent(i1, i2 frontend.Variable) { // AssertIsBoolean fails if v != 0 ∥ v != 1 func (system *compiler) AssertIsBoolean(i1 frontend.Variable) { - if system.IsConstant(i1) { - c := utils.FromInterface(i1) + if c, ok := system.ConstantValue(i1); ok { if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", c.String())) } diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go index 391d852e3d..1215535d53 100644 --- a/frontend/cs/scs/compiler.go +++ b/frontend/cs/scs/compiler.go @@ -192,8 +192,10 @@ func (system *compiler) IsBoolean(v frontend.Variable) bool { // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. func (system *compiler) MarkBoolean(v frontend.Variable) { - if system.IsConstant(v) { - return + if b, ok := system.ConstantValue(v); ok { + if !(b.IsUint64() && b.Uint64() <= 1) { + panic("MarkBoolean called a non-boolean constant") + } } system.mtBooleans[int(v.(compiled.Term))] = struct{}{} } From ddb20130ae90f90352b77d218583fc5d321913f7 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 13:52:31 -0600 Subject: [PATCH 14/20] refactor: factorize coeff table initialization --- frontend/cs/coeff_table.go | 15 ++++++++++++++- frontend/cs/r1cs/compiler.go | 10 ---------- frontend/cs/scs/compiler.go | 13 ------------- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/frontend/cs/coeff_table.go b/frontend/cs/coeff_table.go index dca31934af..b3fba3ddee 100644 --- a/frontend/cs/coeff_table.go +++ b/frontend/cs/coeff_table.go @@ -2,6 +2,8 @@ package cs import ( "math/big" + + "github.com/consensys/gnark/frontend/compiled" ) // CoeffTable helps build a constraint system but need not be serialized after compilation @@ -13,11 +15,22 @@ type CoeffTable struct { } func NewCoeffTable() CoeffTable { - return CoeffTable{ + st := CoeffTable{ Coeffs: make([]big.Int, 4), CoeffsIDsLarge: make(map[string]int), CoeffsIDsInt64: make(map[int64]int, 4), } + + st.Coeffs[compiled.CoeffIdZero].SetInt64(0) + st.Coeffs[compiled.CoeffIdOne].SetInt64(1) + st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + st.CoeffsIDsInt64[0] = compiled.CoeffIdZero + st.CoeffsIDsInt64[1] = compiled.CoeffIdOne + st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + + return st } // CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index 466b34aedd..1b1a1dedb8 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -76,16 +76,6 @@ func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { mtBooleans: make(map[uint64][]compiled.LinearExpression), } - system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - system.Public = make([]string, 1) system.Secret = make([]string, 0) diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go index 1215535d53..7992b31c50 100644 --- a/frontend/cs/scs/compiler.go +++ b/frontend/cs/scs/compiler.go @@ -75,19 +75,6 @@ func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { st: cs.NewCoeffTable(), } - system.st.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.st.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.st.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.st.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - - // system.public.variables = make([]Variable, 0) - // system.secret.variables = make([]Variable, 0) - // system.internal = make([]Variable, 0, capacity) system.Public = make([]string, 0) system.Secret = make([]string, 0) From 7658e9c1150722c7fc79a20eb03bfbec97e6b635 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 13:56:50 -0600 Subject: [PATCH 15/20] perf: restored frontend.WithCapacity option... --- frontend/compile.go | 8 ++++---- frontend/cs/r1cs/compiler.go | 20 +++++++++----------- frontend/cs/r1cs/r1cs_test.go | 3 ++- frontend/cs/scs/compiler.go | 20 +++++++++----------- 4 files changed, 24 insertions(+), 27 deletions(-) diff --git a/frontend/compile.go b/frontend/compile.go index aad3cadb74..c6943367e9 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -18,7 +18,7 @@ type Builder interface { Compiler // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) - Compile(opt CompileConfig) (CompiledConstraintSystem, error) + Compile() (CompiledConstraintSystem, error) // SetSchema is used internally by frontend.Compile to set the circuit schema SetSchema(*schema.Schema) @@ -81,7 +81,7 @@ type Compiler interface { Backend() backend.ID } -type NewCompiler func(ecc.ID) (Builder, error) +type NewCompiler func(ecc.ID, CompileConfig) (Builder, error) // Compile will generate a ConstraintSystem from the given circuit // @@ -111,7 +111,7 @@ func Compile(curveID ecc.ID, newCompiler NewCompiler, circuit Circuit, opts ...C } // instantiate new compiler - compiler, err := newCompiler(curveID) + compiler, err := newCompiler(curveID, opt) if err != nil { return nil, fmt.Errorf("new compiler: %w", err) } @@ -124,7 +124,7 @@ func Compile(curveID ecc.ID, newCompiler NewCompiler, circuit Circuit, opts ...C } // compile the circuit into its final form - return compiler.Compile(opt) + return compiler.Compile() } func parseCircuit(builder Builder, circuit Circuit) (err error) { diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index 1b1a1dedb8..f07dddfa40 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -44,15 +44,16 @@ import ( ) // NewCompiler returns a new R1CS compiler -func NewCompiler(curve ecc.ID) (frontend.Builder, error) { - return newCompiler(curve), nil +func NewCompiler(curve ecc.ID, config frontend.CompileConfig) (frontend.Builder, error) { + return newCompiler(curve, config), nil } type compiler struct { compiled.ConstraintSystem Constraints []compiled.R1C - st cs.CoeffTable + st cs.CoeffTable + config frontend.CompileConfig // map for recording boolean constrained variables (to not constrain them twice) mtBooleans map[uint64][]compiled.LinearExpression @@ -60,20 +61,17 @@ type compiler struct { // initialCapacity has quite some impact on frontend performance, especially on large circuits size // we may want to add build tags to tune that -func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } +func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *compiler { system := compiler{ ConstraintSystem: compiled.ConstraintSystem{ MDebug: make(map[int]int), MHints: make(map[int]*compiled.Hint), }, - Constraints: make([]compiled.R1C, 0, capacity), + Constraints: make([]compiled.R1C, 0, config.Capacity), st: cs.NewCoeffTable(), mtBooleans: make(map[uint64][]compiled.LinearExpression), + config: config, } system.Public = make([]string, 1) @@ -347,10 +345,10 @@ func init() { } // Compile constructs a rank-1 constraint sytem -func (cs *compiler) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { +func (cs *compiler) Compile() (frontend.CompiledConstraintSystem, error) { // ensure all inputs and hints are constrained - if !opt.IgnoreUnconstrainedInputs { + if !cs.config.IgnoreUnconstrainedInputs { if err := cs.checkVariables(); err != nil { return nil, err } diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index 557f2fbb96..ca23706db2 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" ) @@ -50,7 +51,7 @@ func TestQuickSort(t *testing.T) { func TestReduce(t *testing.T) { - cs := newCompiler(ecc.BN254) + cs := newCompiler(ecc.BN254, frontend.CompileConfig{}) x := cs.newInternalVariable() y := cs.newInternalVariable() z := cs.newInternalVariable() diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go index 7992b31c50..441fcf754e 100644 --- a/frontend/cs/scs/compiler.go +++ b/frontend/cs/scs/compiler.go @@ -43,15 +43,16 @@ import ( "github.com/consensys/gnark/internal/utils" ) -func NewCompiler(curve ecc.ID) (frontend.Builder, error) { - return newCompiler(curve), nil +func NewCompiler(curve ecc.ID, config frontend.CompileConfig) (frontend.Builder, error) { + return newCompiler(curve, config), nil } type compiler struct { compiled.ConstraintSystem Constraints []compiled.SparseR1C - st cs.CoeffTable + st cs.CoeffTable + config frontend.CompileConfig // map for recording boolean constrained variables (to not constrain them twice) mtBooleans map[int]struct{} @@ -59,11 +60,7 @@ type compiler struct { // initialCapacity has quite some impact on frontend performance, especially on large circuits size // we may want to add build tags to tune that -func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } +func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *compiler { system := compiler{ ConstraintSystem: compiled.ConstraintSystem{ @@ -71,8 +68,9 @@ func newCompiler(curveID ecc.ID, initialCapacity ...int) *compiler { MHints: make(map[int]*compiled.Hint), }, mtBooleans: make(map[int]struct{}), - Constraints: make([]compiled.SparseR1C, 0, capacity), + Constraints: make([]compiled.SparseR1C, 0, config.Capacity), st: cs.NewCoeffTable(), + config: config, } system.Public = make([]string, 0) @@ -296,10 +294,10 @@ func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() } -func (cs *compiler) Compile(opt frontend.CompileConfig) (frontend.CompiledConstraintSystem, error) { +func (cs *compiler) Compile() (frontend.CompiledConstraintSystem, error) { // ensure all inputs and hints are constrained - if !opt.IgnoreUnconstrainedInputs { + if !cs.config.IgnoreUnconstrainedInputs { if err := cs.checkVariables(); err != nil { return nil, err } From bd1b05fcde638116ddd85e6f805987a3a1c6d433 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 14:02:02 -0600 Subject: [PATCH 16/20] docs: added Deprecated comments in front of APIs moved to Compiler interface --- frontend/api.go | 42 +++++++++++++++++++++++ frontend/compile.go | 73 --------------------------------------- frontend/compiler.go | 81 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 73 deletions(-) create mode 100644 frontend/compiler.go diff --git a/frontend/api.go b/frontend/api.go index 7a184f4678..7f79decb97 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -16,6 +16,14 @@ limitations under the License. package frontend +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" +) + // API represents the available functions to circuit developers type API interface { // --------------------------------------------------------------------------------------------- @@ -106,4 +114,38 @@ type API interface { // Compiler returns the compiler object for advanced circuit development Compiler() Compiler + + // Deprecated APIs + + // MarkBoolean is a shorcut to api.Compiler().MarkBoolean() + // Deprecated: use api.Compiler().MarkBoolean() instead + MarkBoolean(v Variable) + + // IsBoolean is a shorcut to api.Compiler().IsBoolean() + // Deprecated: use api.Compiler().IsBoolean() instead + IsBoolean(v Variable) bool + + // NewHint is a shorcut to api.Compiler().NewHint() + // Deprecated: use api.Compiler().NewHint() instead + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + + // Tag is a shorcut to api.Compiler().Tag() + // Deprecated: use api.Compiler().Tag() instead + Tag(name string) Tag + + // AddCounter is a shorcut to api.Compiler().AddCounter() + // Deprecated: use api.Compiler().AddCounter() instead + AddCounter(from, to Tag) + + // ConstantValue is a shorcut to api.Compiler().ConstantValue() + // Deprecated: use api.Compiler().ConstantValue() instead + ConstantValue(v Variable) (*big.Int, bool) + + // Curve is a shorcut to api.Compiler().Curve() + // Deprecated: use api.Compiler().Curve() instead + Curve() ecc.ID + + // Backend is a shorcut to api.Compiler().Backend() + // Deprecated: use api.Compiler().Backend() instead + Backend() backend.ID } diff --git a/frontend/compile.go b/frontend/compile.go index c6943367e9..f934db4c98 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -3,86 +3,13 @@ package frontend import ( "errors" "fmt" - "math/big" "reflect" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" ) -type Builder interface { - API - Compiler - - // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) - Compile() (CompiledConstraintSystem, error) - - // SetSchema is used internally by frontend.Compile to set the circuit schema - SetSchema(*schema.Schema) - - // AddPublicVariable is called by the compiler when parsing the circuit schema. It panics if - // called inside circuit.Define() - AddPublicVariable(name string) Variable - - // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if - // called inside circuit.Define() - AddSecretVariable(name string) Variable -} - -// Compiler represents a constraint system compiler -type Compiler interface { - // MarkBoolean sets (but do not constraint!) v to be boolean - // This is useful in scenarios where a variable is known to be boolean through a constraint - // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. - MarkBoolean(v Variable) - - // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) - // Use with care; variable may not have been **constrained** to be boolean - // This returns true if the v is a constant and v == 0 || v == 1. - IsBoolean(v Variable) bool - - // NewHint initializes internal variables whose value will be evaluated - // using the provided hint function at run time from the inputs. Inputs must - // be either variables or convertible to *big.Int. The function returns an - // error if the number of inputs is not compatible with f. - // - // The hint function is provided at the proof creation time and is not - // embedded into the circuit. From the backend point of view, the variable - // returned by the hint function is equivalent to the user-supplied witness, - // but its actual value is assigned by the solver, not the caller. - // - // No new constraints are added to the newly created wire and must be added - // manually in the circuit. Failing to do so leads to solver failure. - // - // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs - NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) - - // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to - // measure constraints, variables and coefficients creations through AddCounter - Tag(name string) Tag - - // AddCounter measures the number of constraints, variables and coefficients created between two tags - // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions - // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time - AddCounter(from, to Tag) - - // ConstantValue returns the big.Int value of v and true if op is a success. - // nil and false if failure. This API returns a boolean to allow for future refactoring - // replacing *big.Int with fr.Element - ConstantValue(v Variable) (*big.Int, bool) - - // CurveID returns the ecc.ID injected by the compiler - Curve() ecc.ID - - // Backend returns the backend.ID injected by the compiler - Backend() backend.ID -} - -type NewCompiler func(ecc.ID, CompileConfig) (Builder, error) - // Compile will generate a ConstraintSystem from the given circuit // // 1. it will first allocate the user inputs (see type Tag for more info) diff --git a/frontend/compiler.go b/frontend/compiler.go new file mode 100644 index 0000000000..866e9bc417 --- /dev/null +++ b/frontend/compiler.go @@ -0,0 +1,81 @@ +package frontend + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/schema" +) + +type NewCompiler func(ecc.ID, CompileConfig) (Builder, error) + +// Compiler represents a constraint system compiler +type Compiler interface { + // MarkBoolean sets (but do not constraint!) v to be boolean + // This is useful in scenarios where a variable is known to be boolean through a constraint + // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. + MarkBoolean(v Variable) + + // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) + // Use with care; variable may not have been **constrained** to be boolean + // This returns true if the v is a constant and v == 0 || v == 1. + IsBoolean(v Variable) bool + + // NewHint initializes internal variables whose value will be evaluated + // using the provided hint function at run time from the inputs. Inputs must + // be either variables or convertible to *big.Int. The function returns an + // error if the number of inputs is not compatible with f. + // + // The hint function is provided at the proof creation time and is not + // embedded into the circuit. From the backend point of view, the variable + // returned by the hint function is equivalent to the user-supplied witness, + // but its actual value is assigned by the solver, not the caller. + // + // No new constraints are added to the newly created wire and must be added + // manually in the circuit. Failing to do so leads to solver failure. + // + // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + + // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to + // measure constraints, variables and coefficients creations through AddCounter + Tag(name string) Tag + + // AddCounter measures the number of constraints, variables and coefficients created between two tags + // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions + // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time + AddCounter(from, to Tag) + + // ConstantValue returns the big.Int value of v and true if op is a success. + // nil and false if failure. This API returns a boolean to allow for future refactoring + // replacing *big.Int with fr.Element + ConstantValue(v Variable) (*big.Int, bool) + + // CurveID returns the ecc.ID injected by the compiler + Curve() ecc.ID + + // Backend returns the backend.ID injected by the compiler + Backend() backend.ID +} + +// Builder represents a constraint system builder +type Builder interface { + API + Compiler + + // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) + Compile() (CompiledConstraintSystem, error) + + // SetSchema is used internally by frontend.Compile to set the circuit schema + SetSchema(*schema.Schema) + + // AddPublicVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddPublicVariable(name string) Variable + + // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddSecretVariable(name string) Variable +} From 68a3bb3233ae2bc563aaec714a7feb3b33a82e10 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 14:14:35 -0600 Subject: [PATCH 17/20] refactor: compiled.Variable -> compiled.LinearExpression --- frontend/compiled/cs.go | 6 +- frontend/compiled/hint.go | 9 -- frontend/compiled/log.go | 8 +- frontend/compiled/r1cs.go | 2 +- frontend/compiled/variable.go | 24 ++--- frontend/cs/r1cs/api.go | 58 +++++----- frontend/cs/r1cs/api_assertions.go | 16 +-- frontend/cs/r1cs/compiler.go | 102 ++++++++---------- frontend/cs/r1cs/r1cs_test.go | 4 +- frontend/cs/scs/compiler.go | 16 ++- internal/backend/bls12-377/cs/r1cs.go | 14 +-- internal/backend/bls12-377/cs/solution.go | 3 - internal/backend/bls12-377/groth16/setup.go | 10 +- internal/backend/bls12-381/cs/r1cs.go | 14 +-- internal/backend/bls12-381/cs/solution.go | 3 - internal/backend/bls12-381/groth16/setup.go | 10 +- internal/backend/bls24-315/cs/r1cs.go | 14 +-- internal/backend/bls24-315/cs/solution.go | 3 - internal/backend/bls24-315/groth16/setup.go | 10 +- internal/backend/bn254/cs/r1cs.go | 14 +-- internal/backend/bn254/cs/solution.go | 3 - internal/backend/bn254/groth16/setup.go | 10 +- internal/backend/bw6-633/cs/r1cs.go | 14 +-- internal/backend/bw6-633/cs/solution.go | 3 - internal/backend/bw6-633/groth16/setup.go | 10 +- internal/backend/bw6-761/cs/r1cs.go | 14 +-- internal/backend/bw6-761/cs/solution.go | 3 - internal/backend/bw6-761/groth16/setup.go | 10 +- .../template/representations/r1cs.go.tmpl | 14 +-- .../template/representations/solution.go.tmpl | 3 - .../zkpschemes/groth16/groth16.setup.go.tmpl | 10 +- 31 files changed, 189 insertions(+), 245 deletions(-) diff --git a/frontend/compiled/cs.go b/frontend/compiled/cs.go index a2070579af..ea6a87d2fe 100644 --- a/frontend/compiled/cs.go +++ b/frontend/compiled/cs.go @@ -90,12 +90,12 @@ func (cs *ConstraintSystem) AddDebugInfo(errName string, i ...interface{}) int { for _, _i := range i { switch v := _i.(type) { - case Variable: - if len(v.LinExp) > 1 { + case LinearExpression: + if len(v) > 1 { sbb.WriteString("(") } l.WriteVariable(v, &sbb) - if len(v.LinExp) > 1 { + if len(v) > 1 { sbb.WriteString(")") } case string: diff --git a/frontend/compiled/hint.go b/frontend/compiled/hint.go index 8aeee12208..90ef4c8caa 100644 --- a/frontend/compiled/hint.go +++ b/frontend/compiled/hint.go @@ -24,9 +24,6 @@ func (h Hint) inputsCBORTags() (cbor.TagSet, error) { if err := tags.Add(defTagOpts, reflect.TypeOf(LinearExpression{}), 25443); err != nil { return nil, fmt.Errorf("new LE tag: %w", err) } - if err := tags.Add(defTagOpts, reflect.TypeOf(Variable{}), 25444); err != nil { - return nil, fmt.Errorf("new variable tag: %w", err) - } if err := tags.Add(defTagOpts, reflect.TypeOf(Term(0)), 25445); err != nil { return nil, fmt.Errorf("new term tag: %w", err) } @@ -100,12 +97,6 @@ func (h *Hint) UnmarshalCBOR(b []byte) error { return fmt.Errorf("unmarshal linear expression: %w", err) } inputs[i] = LinearExpression(v) - case 25444: - var v Variable - if err := dec.Unmarshal(vin.Content, &v); err != nil { - return fmt.Errorf("unmarshal variable: %w", err) - } - inputs[i] = v case 25445: var v uint64 if err := dec.Unmarshal(vin.Content, &v); err != nil { diff --git a/frontend/compiled/log.go b/frontend/compiled/log.go index a578b80b92..b25676e006 100644 --- a/frontend/compiled/log.go +++ b/frontend/compiled/log.go @@ -28,14 +28,14 @@ type LogEntry struct { ToResolve []Term } -func (l *LogEntry) WriteVariable(le Variable, sbb *strings.Builder) { - sbb.Grow(len(le.LinExp) * len(" + (xx + xxxxxxxxxxxx")) +func (l *LogEntry) WriteVariable(le LinearExpression, sbb *strings.Builder) { + sbb.Grow(len(le) * len(" + (xx + xxxxxxxxxxxx")) - for i := 0; i < len(le.LinExp); i++ { + for i := 0; i < len(le); i++ { if i > 0 { sbb.WriteString(" + ") } - l.WriteTerm(le.LinExp[i], sbb) + l.WriteTerm(le[i], sbb) } } diff --git a/frontend/compiled/r1cs.go b/frontend/compiled/r1cs.go index f2645385b7..2bf3a0d8a6 100644 --- a/frontend/compiled/r1cs.go +++ b/frontend/compiled/r1cs.go @@ -32,7 +32,7 @@ func (r1cs *R1CS) GetNbConstraints() int { // R1C used to compute the wires type R1C struct { - L, R, O Variable + L, R, O LinearExpression } func (r1c *R1C) String(coeffs []big.Int) string { diff --git a/frontend/compiled/variable.go b/frontend/compiled/variable.go index 4c3e2200a3..f4d329fdcf 100644 --- a/frontend/compiled/variable.go +++ b/frontend/compiled/variable.go @@ -23,23 +23,17 @@ import ( // errNoValue triggered when trying to access a variable that was not allocated var errNoValue = errors.New("can't determine API input value") -// Variable represent a linear expression of wires -type Variable struct { - LinExp LinearExpression -} - // Clone returns a copy of the underlying slice -func (v Variable) Clone() Variable { - var res Variable - res.LinExp = make([]Term, len(v.LinExp)) - copy(res.LinExp, v.LinExp) +func (v LinearExpression) Clone() LinearExpression { + res := make(LinearExpression, len(v)) + copy(res, v) return res } -func (v Variable) string(sbb *strings.Builder, coeffs []big.Int) { - for i := 0; i < len(v.LinExp); i++ { - v.LinExp[i].string(sbb, coeffs) - if i+1 < len(v.LinExp) { +func (v LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { + for i := 0; i < len(v); i++ { + v[i].string(sbb, coeffs) + if i+1 < len(v) { sbb.WriteString(" + ") } } @@ -50,9 +44,9 @@ func (v Variable) string(sbb *strings.Builder, coeffs []big.Int) { // var a variable // cs.Mul(a, 1) // since a was not in the circuit struct it is not a secret variable -func (v Variable) AssertIsSet() { +func (v LinearExpression) AssertIsSet() { - if len(v.LinExp) == 0 { + if len(v) == 0 { panic(errNoValue) } diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index cc6287412c..c0be62114e 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -41,11 +41,11 @@ func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) f vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - res := compiled.Variable{LinExp: make([]compiled.Term, 0, s)} + res := make(compiled.LinearExpression, 0, s) for _, v := range vars { l := v.Clone() - res.LinExp = append(res.LinExp, l.LinExp...) + res = append(res, l...) } res = system.reduce(res) @@ -62,9 +62,7 @@ func (system *compiler) Neg(i frontend.Variable) frontend.Variable { return system.toVariable(n) } - res := compiled.Variable{LinExp: system.negateLinExp(vars[0].LinExp)} - - return res + return system.negateLinExp(vars[0]) } // Sub returns res = i1 - i2 @@ -74,15 +72,13 @@ func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - res := compiled.Variable{ - LinExp: make([]compiled.Term, 0, s), - } + res := make(compiled.LinearExpression, 0, s) c := vars[0].Clone() - res.LinExp = append(res.LinExp, c.LinExp...) + res = append(res, c...) for i := 1; i < len(vars); i++ { - negLinExp := system.negateLinExp(vars[i].LinExp) - res.LinExp = append(res.LinExp, negLinExp...) + negLinExp := system.negateLinExp(vars[i]) + res = append(res, negLinExp...) } // reduce linear expression @@ -95,7 +91,7 @@ func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, _ := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) - mul := func(v1, v2 compiled.Variable) compiled.Variable { + mul := func(v1, v2 compiled.LinearExpression) compiled.LinearExpression { n1, v1Constant := system.ConstantValue(v1) n2, v2Constant := system.ConstantValue(v2) @@ -110,7 +106,7 @@ func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f // v1 and v2 are constants, we multiply big.Int values and return resulting constant if v1Constant && v2Constant { n1.Mul(n1, n2).Mod(n1, system.CurveID.Info().Fr.Modulus()) - return system.toVariable(n1).(compiled.Variable) + return system.toVariable(n1).(compiled.LinearExpression) } // ensure v2 is the constant @@ -130,13 +126,13 @@ func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f return res } -func (system *compiler) mulConstant(v1, constant compiled.Variable) compiled.Variable { +func (system *compiler) mulConstant(v1, constant compiled.LinearExpression) compiled.LinearExpression { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable res := v1.Clone() lambda, _ := system.ConstantValue(constant) - for i, t := range v1.LinExp { + for i, t := range v1 { cID, vID, visibility := t.Unpack() var newCoeff big.Int switch cID { @@ -152,7 +148,7 @@ func (system *compiler) mulConstant(v1, constant compiled.Variable) compiled.Var coeff := system.st.Coeffs[cID] newCoeff.Mul(&coeff, lambda) } - res.LinExp[i] = compiled.Pack(vID, system.st.CoeffID(&newCoeff), visibility) + res[i] = compiled.Pack(vID, system.st.CoeffID(&newCoeff), visibility) } return res } @@ -187,7 +183,7 @@ func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable } // v1 is not constant - return system.mulConstant(v1, system.toVariable(n2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.LinearExpression)) } // Div returns res = i1 / i2 @@ -223,7 +219,7 @@ func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { } // v1 is not constant - return system.mulConstant(v1, system.toVariable(n2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.LinearExpression)) } // Inverse returns res = inverse(v) @@ -283,7 +279,7 @@ func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari } // toBinary is equivalent to ToBinary, exept the returned bits are NOT boolean constrained. -func (system *compiler) toBinary(a compiled.Variable, nbBits int, unsafe bool) []frontend.Variable { +func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe bool) []frontend.Variable { if _, ok := system.ConstantValue(a); ok { return system.ToBinary(a, nbBits) @@ -310,7 +306,7 @@ func (system *compiler) toBinary(a compiled.Variable, nbBits int, unsafe bool) [ } } - //var Σbi compiled.Variable + //var Σbi compiled.LinearExpression var Σbi frontend.Variable if nbBits == 1 { system.AssertIsEqual(sb[0], a) @@ -343,7 +339,7 @@ func (system *compiler) FromBinary(_b ...frontend.Variable) frontend.Variable { var c big.Int c.SetUint64(1) - L := make([]compiled.Term, len(b)) + L := make(compiled.LinearExpression, len(b)) for i := 0; i < len(L); i++ { v = system.Mul(c, b[i]) // no constraint is recorded res = system.Add(v, res) // no constraint is recorded @@ -368,8 +364,8 @@ func (system *compiler) Xor(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() system.MarkBoolean(res) - c := system.Neg(res).(compiled.Variable) - c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) + c := system.Neg(res).(compiled.LinearExpression) + c = append(c, a[0], b[0]) aa := system.Mul(a, 2) system.Constraints = append(system.Constraints, newR1C(aa, b, c)) @@ -389,8 +385,8 @@ func (system *compiler) Or(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() system.MarkBoolean(res) - c := system.Neg(res).(compiled.Variable) - c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) + c := system.Neg(res).(compiled.LinearExpression) + c = append(c, a[0], b[0]) system.Constraints = append(system.Constraints, newR1C(a, b, c)) return res @@ -559,14 +555,14 @@ func (system *compiler) Println(a ...frontend.Variable) { if i > 0 { sbb.WriteByte(' ') } - if v, ok := arg.(compiled.Variable); ok { + if v, ok := arg.(compiled.LinearExpression); ok { v.AssertIsSet() sbb.WriteString("%s") // we set limits to the linear expression, so that the log printer // can evaluate it before printing it log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - log.ToResolve = append(log.ToResolve, v.LinExp...) + log.ToResolve = append(log.ToResolve, v...) log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) } else { printArg(&log, &sbb, arg) @@ -606,11 +602,11 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) sbb.WriteString(", ") } - v := tValue.Interface().(compiled.Variable) + v := tValue.Interface().(compiled.LinearExpression) // we set limits to the linear expression, so that the log printer // can evaluate it before printing it log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - log.ToResolve = append(log.ToResolve, v.LinExp...) + log.ToResolve = append(log.ToResolve, v...) log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) return nil } @@ -620,8 +616,8 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) } // returns -le, the result is a copy -func (system *compiler) negateLinExp(l []compiled.Term) []compiled.Term { - res := make([]compiled.Term, len(l)) +func (system *compiler) negateLinExp(l compiled.LinearExpression) compiled.LinearExpression { + res := make(compiled.LinearExpression, len(l)) var lambda big.Int for i, t := range l { cID, vID, visibility := t.Unpack() diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index ea2f4f630c..fae3e601f8 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -28,8 +28,8 @@ import ( // AssertIsEqual adds an assertion in the constraint system (i1 == i2) func (system *compiler) AssertIsEqual(i1, i2 frontend.Variable) { // encoded 1 * i1 == i2 - r := system.toVariable(i1).(compiled.Variable) - o := system.toVariable(i2).(compiled.Variable) + r := system.toVariable(i1).(compiled.LinearExpression) + o := system.toVariable(i2).(compiled.LinearExpression) debug := system.AddDebugInfo("assertIsEqual", r, " == ", o) @@ -55,7 +55,7 @@ func (system *compiler) AssertIsBoolean(i1 frontend.Variable) { } if system.IsBoolean(v) { - return // compiled.Variable is already constrained + return // compiled.LinearExpression is already constrained } system.MarkBoolean(v) @@ -78,7 +78,7 @@ func (system *compiler) AssertIsLessOrEqual(_v frontend.Variable, bound frontend v, _ := system.toVariables(_v) switch b := bound.(type) { - case compiled.Variable: + case compiled.LinearExpression: b.AssertIsSet() system.mustBeLessOrEqVar(v[0], b) default: @@ -87,7 +87,7 @@ func (system *compiler) AssertIsLessOrEqual(_v frontend.Variable, bound frontend } -func (system *compiler) mustBeLessOrEqVar(a, bound compiled.Variable) { +func (system *compiler) mustBeLessOrEqVar(a, bound compiled.LinearExpression) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := system.BitLen() @@ -121,14 +121,14 @@ func (system *compiler) mustBeLessOrEqVar(a, bound compiled.Variable) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - system.MarkBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint + system.MarkBoolean(aBits[i].(compiled.LinearExpression)) // this does not create a constraint system.addConstraint(newR1C(l, aBits[i], zero), debug) } } -func (system *compiler) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { +func (system *compiler) mustBeLessOrEqCst(a compiled.LinearExpression, bound big.Int) { nbBits := system.BitLen() @@ -175,7 +175,7 @@ func (system *compiler) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { l = system.Sub(l, aBits[i]) system.addConstraint(newR1C(l, aBits[i], system.toVariable(0)), debug) - system.MarkBoolean(aBits[i].(compiled.Variable)) + system.MarkBoolean(aBits[i].(compiled.LinearExpression)) } else { system.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index f07dddfa40..74a88f5fef 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -87,11 +87,11 @@ func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *compiler { // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets // the wire's id to the number of wires, and returns it -func (system *compiler) newInternalVariable() compiled.Variable { +func (system *compiler) newInternalVariable() compiled.LinearExpression { idx := system.NbInternalVariables system.NbInternalVariables++ - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal), } } @@ -102,10 +102,9 @@ func (system *compiler) AddPublicVariable(name string) frontend.Variable { } idx := len(system.Public) system.Public = append(system.Public, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Public), } - return res } // AddSecretVariable creates a new secret Variable @@ -115,15 +114,14 @@ func (system *compiler) AddSecretVariable(name string) frontend.Variable { } idx := len(system.Secret) system.Secret = append(system.Secret, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret), } - return res } -func (system *compiler) one() compiled.Variable { - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, +func (system *compiler) one() compiled.LinearExpression { + return compiled.LinearExpression{ + compiled.Pack(0, compiled.CoeffIdOne, schema.Public), } } @@ -131,23 +129,23 @@ func (system *compiler) one() compiled.Variable { // It factorizes Variable that appears multiple times with != coeff Ids // To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *compiler) reduce(l compiled.Variable) compiled.Variable { +func (system *compiler) reduce(l compiled.LinearExpression) compiled.LinearExpression { // ensure our linear expression is sorted, by visibility and by Variable ID - if !sort.IsSorted(l.LinExp) { // may not help - sort.Sort(l.LinExp) + if !sort.IsSorted(l) { // may not help + sort.Sort(l) } mod := system.CurveID.Info().Fr.Modulus() c := new(big.Int) - for i := 1; i < len(l.LinExp); i++ { - pcID, pvID, pVis := l.LinExp[i-1].Unpack() - ccID, cvID, cVis := l.LinExp[i].Unpack() + for i := 1; i < len(l); i++ { + pcID, pvID, pVis := l[i-1].Unpack() + ccID, cvID, cVis := l[i].Unpack() if pVis == cVis && pvID == cvID { // we have redundancy c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) c.Mod(c, mod) - l.LinExp[i-1].SetCoeffID(system.st.CoeffID(c)) - l.LinExp = append(l.LinExp[:i], l.LinExp[i+1:]...) + l[i-1].SetCoeffID(system.st.CoeffID(c)) + l = append(l[:i], l[i+1:]...) i-- } } @@ -157,9 +155,9 @@ func (system *compiler) reduce(l compiled.Variable) compiled.Variable { // newR1C clones the linear expression associated with the Variables (to avoid offseting the ID multiple time) // and return a R1C func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { - l := _l.(compiled.Variable) - r := _r.(compiled.Variable) - o := _o.(compiled.Variable) + l := _l.(compiled.LinearExpression) + r := _r.(compiled.LinearExpression) + o := _o.(compiled.LinearExpression) // interestingly, this is key to groth16 performance. // l * r == r * l == o @@ -167,7 +165,7 @@ func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { // the "r" linear expression is going to end up in the B matrix // the less Variable we have appearing in the B matrix, the more likely groth16.Setup // is going to produce infinity points in pk.G1.B and pk.G2.B, which will speed up proving time - if len(l.LinExp) > len(r.LinExp) { + if len(l) > len(r) { l, r = r, l } @@ -199,7 +197,7 @@ func (system *compiler) MarkBoolean(v frontend.Variable) { return } // v is a linear expression - l := v.(compiled.Variable).LinExp + l := v.(compiled.LinearExpression) if !sort.IsSorted(l) { sort.Sort(l) } @@ -218,7 +216,7 @@ func (system *compiler) IsBoolean(v frontend.Variable) bool { return b.IsUint64() && b.Uint64() <= 1 } // v is a linear expression - l := v.(compiled.Variable).LinExp + l := v.(compiled.LinearExpression) if !sort.IsSorted(l) { sort.Sort(l) } @@ -258,8 +256,8 @@ func (system *compiler) checkVariables() error { mHintsConstrained := make(map[int]bool) // for each constraint, we check the linear expressions and mark our inputs / hints as constrained - processLinearExpression := func(l compiled.Variable) { - for _, t := range l.LinExp { + processLinearExpression := func(l compiled.LinearExpression) { + for _, t := range l { if t.CoeffID() == compiled.CoeffIdZero { // ignore zero coefficient, as it does not constraint the Variable // though, we may want to flag that IF the Variable doesn't appear else where @@ -390,9 +388,9 @@ func (cs *compiler) Compile() (frontend.CompiledConstraintSystem, error) { } for i := 0; i < len(res.Constraints); i++ { - offsetIDs(res.Constraints[i].L.LinExp) - offsetIDs(res.Constraints[i].R.LinExp) - offsetIDs(res.Constraints[i].O.LinExp) + offsetIDs(res.Constraints[i].L) + offsetIDs(res.Constraints[i].R) + offsetIDs(res.Constraints[i].O) } // we need to offset the ids in the hints @@ -414,11 +412,6 @@ HINTLOOP: copy(inputs, hint.Inputs) for j := 0; j < len(inputs); j++ { switch t := inputs[j].(type) { - case compiled.Variable: - tmp := make(compiled.LinearExpression, len(t.LinExp)) - copy(tmp, t.LinExp) - offsetIDs(tmp) - inputs[j] = tmp case compiled.LinearExpression: tmp := make(compiled.LinearExpression, len(t)) copy(tmp, t) @@ -496,9 +489,9 @@ func buildLevels(ccs compiled.R1CS) [][]int { b.nodeLevel = 0 - b.processLE(c.L.LinExp, cID) - b.processLE(c.R.LinExp, cID) - b.processLE(c.O.LinExp, cID) + b.processLE(c.L, cID) + b.processLE(c.R, cID) + b.processLE(c.O, cID) b.nodeLevels[cID] = b.nodeLevel b.mLevels[b.nodeLevel]++ @@ -554,8 +547,6 @@ func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { for _, in := range h.Inputs { switch t := in.(type) { - case compiled.Variable: - b.processLE(t.LinExp, cID) case compiled.LinearExpression: b.processLE(t, cID) case compiled.Term: @@ -577,13 +568,13 @@ func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { // ConstantValue returns the big.Int value of v. // Will panic if v.IsConstant() == false func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { - if _v, ok := v.(compiled.Variable); ok { + if _v, ok := v.(compiled.LinearExpression); ok { _v.AssertIsSet() - if len(_v.LinExp) != 1 { + if len(_v) != 1 { return nil, false } - cID, vID, visibility := _v.LinExp[0].Unpack() + cID, vID, visibility := _v[0].Unpack() if !(vID == 0 && visibility == schema.Public) { return nil, false } @@ -597,14 +588,14 @@ func (system *compiler) Backend() backend.ID { return backend.GROTH16 } -// toVariable will return (and allocate if neccesary) a compiled.Variable from given value +// toVariable will return (and allocate if neccesary) a compiled.LinearExpression from given value // -// if input is already a compiled.Variable, does nothing -// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable compiled.Variable +// if input is already a compiled.LinearExpression, does nothing +// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable compiled.LinearExpression func (system *compiler) toVariable(input interface{}) frontend.Variable { switch t := input.(type) { - case compiled.Variable: + case compiled.LinearExpression: t.AssertIsSet() return t default: @@ -613,19 +604,19 @@ func (system *compiler) toVariable(input interface{}) frontend.Variable { return system.one() } r := system.one() - r.LinExp[0] = system.setCoeff(r.LinExp[0], &n) + r[0] = system.setCoeff(r[0], &n) return r } } // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (system *compiler) toVariables(in ...frontend.Variable) ([]compiled.Variable, int) { - r := make([]compiled.Variable, 0, len(in)) +func (system *compiler) toVariables(in ...frontend.Variable) ([]compiled.LinearExpression, int) { + r := make([]compiled.LinearExpression, 0, len(in)) s := 0 e := func(i frontend.Variable) { - v := system.toVariable(i).(compiled.Variable) + v := system.toVariable(i).(compiled.LinearExpression) r = append(r, v) - s += len(v.LinExp) + s += len(v) } // e(i1) // e(i2) @@ -681,9 +672,6 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte // ensure inputs are set and pack them in a []uint64 for i, in := range inputs { switch t := in.(type) { - case compiled.Variable: - tmp := t.Clone() - hintInputs[i] = tmp.LinExp case compiled.LinearExpression: tmp := make(compiled.LinearExpression, len(t)) copy(tmp, t) @@ -698,7 +686,7 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte res := make([]frontend.Variable, len(varIDs)) for i := range varIDs { r := system.newInternalVariable() - _, vID, _ := r.LinExp[0].Unpack() + _, vID, _ := r[0].Unpack() varIDs[i] = vID res[i] = r } diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index ca23706db2..97c8a2f6d9 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -63,10 +63,10 @@ func TestReduce(t *testing.T) { e := cs.Mul(z, 2) f := cs.Mul(z, 2) - toTest := (cs.Add(a, b, c, d, e, f)).(compiled.Variable) + toTest := (cs.Add(a, b, c, d, e, f)).(compiled.LinearExpression) // check sizes - if len(toTest.LinExp) != 3 { + if len(toTest) != 3 { t.Fatal("Error reduce, duplicate variables not collapsed") } diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go index 441fcf754e..cad6a0786a 100644 --- a/frontend/cs/scs/compiler.go +++ b/frontend/cs/scs/compiler.go @@ -492,10 +492,6 @@ func (b *levelBuilder) processTerm(t compiled.Term, cID int) { for _, in := range h.Inputs { switch t := in.(type) { - case compiled.Variable: - for _, tt := range t.LinExp { - b.processTerm(tt, cID) - } case compiled.LinearExpression: for _, tt := range t { b.processTerm(tt, cID) @@ -608,8 +604,8 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte } // returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt -func (system *compiler) filterConstantSum(in []frontend.Variable) ([]compiled.Term, big.Int) { - res := make([]compiled.Term, 0, len(in)) +func (system *compiler) filterConstantSum(in []frontend.Variable) (compiled.LinearExpression, big.Int) { + res := make(compiled.LinearExpression, 0, len(in)) var b big.Int for i := 0; i < len(in); i++ { switch t := in[i].(type) { @@ -624,8 +620,8 @@ func (system *compiler) filterConstantSum(in []frontend.Variable) ([]compiled.Te } // returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt -func (system *compiler) filterConstantProd(in []frontend.Variable) ([]compiled.Term, big.Int) { - res := make([]compiled.Term, 0, len(in)) +func (system *compiler) filterConstantProd(in []frontend.Variable) (compiled.LinearExpression, big.Int) { + res := make(compiled.LinearExpression, 0, len(in)) var b big.Int b.SetInt64(1) for i := 0; i < len(in); i++ { @@ -640,7 +636,7 @@ func (system *compiler) filterConstantProd(in []frontend.Variable) ([]compiled.T return res, b } -func (system *compiler) splitSum(acc compiled.Term, r []compiled.Term) compiled.Term { +func (system *compiler) splitSum(acc compiled.Term, r compiled.LinearExpression) compiled.Term { // floor case if len(r) == 0 { @@ -654,7 +650,7 @@ func (system *compiler) splitSum(acc compiled.Term, r []compiled.Term) compiled. return system.splitSum(o, r[1:]) } -func (system *compiler) splitProd(acc compiled.Term, r []compiled.Term) compiled.Term { +func (system *compiler) splitProd(acc compiled.Term, r compiled.LinearExpression) compiled.Term { // floor case if len(r) == 0 { diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index a73732f8e8..85b3b3694a 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 80c77b6a9f..b3d63f66d9 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 16e7b6926d..95112cddee 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 56f2b50f40..4beef74f02 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 3304585789..d577ae48d9 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index e30ddcff5c..b76aa9c87f 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 8fb821ccb3..a66744f5f0 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 6d10576504..8e4d24cd8a 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index 3a7dea1d87..ad5252165e 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 27874fd9ed..4db2916282 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index b58d02cbe6..1d8d337c0c 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index a5280b909b..2ed4f2e6ec 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 4f6afe073b..ea70db009c 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index 529820639f..06247264c2 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go index 9a6b230ed7..27e2db3700 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/internal/backend/bw6-633/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index e774f7408c..bb29df9a55 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index 81a0600110..d0b6dcf512 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 06127edd0b..1f8d2d84ad 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 45dd4fc4f6..09b1518a13 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -292,15 +292,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.El return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -379,11 +379,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index 4929438644..98cca89b1a 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -151,9 +151,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 5d25b59ff5..3a9719888e 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -315,13 +315,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -470,10 +470,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } From 95c80f1e8d7c704862cc3be2807814b26d016556 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 14:24:47 -0600 Subject: [PATCH 18/20] fix: incorrect handling of nbBits == 1 in api.ToBinary --- frontend/api.go | 8 -------- frontend/cs/r1cs/api.go | 2 +- frontend/cs/r1cs/compiler.go | 1 + frontend/cs/scs/api.go | 2 +- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/frontend/api.go b/frontend/api.go index 7f79decb97..2c1961eb15 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -117,14 +117,6 @@ type API interface { // Deprecated APIs - // MarkBoolean is a shorcut to api.Compiler().MarkBoolean() - // Deprecated: use api.Compiler().MarkBoolean() instead - MarkBoolean(v Variable) - - // IsBoolean is a shorcut to api.Compiler().IsBoolean() - // Deprecated: use api.Compiler().IsBoolean() instead - IsBoolean(v Variable) bool - // NewHint is a shorcut to api.Compiler().NewHint() // Deprecated: use api.Compiler().NewHint() instead NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index c0be62114e..f917863110 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -309,7 +309,7 @@ func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe //var Σbi compiled.LinearExpression var Σbi frontend.Variable if nbBits == 1 { - system.AssertIsEqual(sb[0], a) + Σbi = sb[0] } else if nbBits == 2 { Σbi = system.Add(sb[0], sb[1]) } else { diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index 74a88f5fef..035ed9da76 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -673,6 +673,7 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte for i, in := range inputs { switch t := in.(type) { case compiled.LinearExpression: + t.AssertIsSet() tmp := make(compiled.LinearExpression, len(t)) copy(tmp, t) hintInputs[i] = tmp diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 67796cb23e..240fbd4f77 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -221,7 +221,7 @@ func (system *compiler) toBinary(a compiled.Term, nbBits int, unsafe bool) []fro // TODO we can save a constraint here var Σbi frontend.Variable if nbBits == 1 { - system.AssertIsEqual(sb[0], a) + Σbi = sb[0] } else if nbBits == 2 { Σbi = system.Add(sb[0], sb[1]) } else { From fbae8eca276d6569f20ddd0f750e56d0135289bf Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 28 Feb 2022 14:46:41 -0600 Subject: [PATCH 19/20] style: code cleaning --- frontend/compiled/linear_expression.go | 87 ++++++++++++++++++++++++++ frontend/compiled/symbol.go | 6 -- frontend/compiled/term.go | 51 --------------- frontend/compiled/variable.go | 53 ---------------- frontend/cs/r1cs/api.go | 10 +-- frontend/cs/r1cs/api_assertions.go | 2 +- frontend/cs/r1cs/compiler.go | 21 ++++++- 7 files changed, 107 insertions(+), 123 deletions(-) create mode 100644 frontend/compiled/linear_expression.go delete mode 100644 frontend/compiled/symbol.go delete mode 100644 frontend/compiled/variable.go diff --git a/frontend/compiled/linear_expression.go b/frontend/compiled/linear_expression.go new file mode 100644 index 0000000000..46baa9c868 --- /dev/null +++ b/frontend/compiled/linear_expression.go @@ -0,0 +1,87 @@ +// Copyright 2021 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compiled + +import ( + "math/big" + "strings" +) + +// A linear expression is a linear combination of Term +type LinearExpression []Term + +// Clone returns a copy of the underlying slice +func (l LinearExpression) Clone() LinearExpression { + res := make(LinearExpression, len(l)) + copy(res, l) + return res +} + +func (l LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { + for i := 0; i < len(l); i++ { + l[i].string(sbb, coeffs) + if i+1 < len(l) { + sbb.WriteString(" + ") + } + } +} + +// Len return the lenght of the Variable (implements Sort interface) +func (l LinearExpression) Len() int { + return len(l) +} + +// Equals returns true if both SORTED expressions are the same +// +// pre conditions: l and o are sorted +func (l LinearExpression) Equal(o LinearExpression) bool { + if len(l) != len(o) { + return false + } + if (l == nil) != (o == nil) { + return false + } + for i := 0; i < len(l); i++ { + if l[i] != o[i] { + return false + } + } + return true +} + +// Swap swaps terms in the Variable (implements Sort interface) +func (l LinearExpression) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) +func (l LinearExpression) Less(i, j int) bool { + _, iID, iVis := l[i].Unpack() + _, jID, jVis := l[j].Unpack() + if iVis == jVis { + return iID < jID + } + return iVis > jVis +} + +// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear +// expression +func (l LinearExpression) HashCode() uint64 { + h := uint64(17) + for _, val := range l { + h = h*23 + uint64(val) + } + return h +} diff --git a/frontend/compiled/symbol.go b/frontend/compiled/symbol.go deleted file mode 100644 index 62675d97af..0000000000 --- a/frontend/compiled/symbol.go +++ /dev/null @@ -1,6 +0,0 @@ -package compiled - -type Symbol interface { - AssertIsSet() - IsConstant() bool -} diff --git a/frontend/compiled/term.go b/frontend/compiled/term.go index 1a84217e15..08148636e2 100644 --- a/frontend/compiled/term.go +++ b/frontend/compiled/term.go @@ -27,9 +27,6 @@ import ( // note: if we support more than 1 billion constraints, this breaks (not so soon.) type Term uint64 -// A linear expression is a linear combination of Term -type LinearExpression []Term - // ids of the coefficients with simple values in any cs.coeffs slice. const ( CoeffIdZero = 0 @@ -179,51 +176,3 @@ func (t Term) string(sbb *strings.Builder, coeffs []big.Int) { } sbb.WriteString(strconv.Itoa(t.WireID())) } - -// Len return the lenght of the Variable (implements Sort interface) -func (v LinearExpression) Len() int { - return len(v) -} - -// Equals returns true if both SORTED expressions are the same -// -// pre conditions: l and o are sorted -func (v LinearExpression) Equal(o LinearExpression) bool { - if len(v) != len(o) { - return false - } - if (v == nil) != (o == nil) { - return false - } - for i := 0; i < len(v); i++ { - if v[i] != o[i] { - return false - } - } - return true -} - -// Swap swaps terms in the Variable (implements Sort interface) -func (v LinearExpression) Swap(i, j int) { - v[i], v[j] = v[j], v[i] -} - -// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) -func (v LinearExpression) Less(i, j int) bool { - _, iID, iVis := v[i].Unpack() - _, jID, jVis := v[j].Unpack() - if iVis == jVis { - return iID < jID - } - return iVis > jVis -} - -// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear -// expression -func (v LinearExpression) HashCode() uint64 { - h := uint64(17) - for _, val := range v { - h = h*23 + uint64(val) - } - return h -} diff --git a/frontend/compiled/variable.go b/frontend/compiled/variable.go deleted file mode 100644 index f4d329fdcf..0000000000 --- a/frontend/compiled/variable.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2021 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -import ( - "errors" - "math/big" - "strings" -) - -// errNoValue triggered when trying to access a variable that was not allocated -var errNoValue = errors.New("can't determine API input value") - -// Clone returns a copy of the underlying slice -func (v LinearExpression) Clone() LinearExpression { - res := make(LinearExpression, len(v)) - copy(res, v) - return res -} - -func (v LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { - for i := 0; i < len(v); i++ { - v[i].string(sbb, coeffs) - if i+1 < len(v) { - sbb.WriteString(" + ") - } - } -} - -// assertIsSet panics if the variable is unset -// this may happen if inside a Define we have -// var a variable -// cs.Mul(a, 1) -// since a was not in the circuit struct it is not a secret variable -func (v LinearExpression) AssertIsSet() { - - if len(v) == 0 { - panic(errNoValue) - } - -} diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index f917863110..da921bc727 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -285,9 +285,6 @@ func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe return system.ToBinary(a, nbBits) } - // ensure a is set - a.AssertIsSet() - // allocate the resulting frontend.Variables and bit-constraint them sb := make([]frontend.Variable, nbBits) var c big.Int @@ -326,11 +323,6 @@ func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe func (system *compiler) FromBinary(_b ...frontend.Variable) frontend.Variable { b, _ := system.toVariables(_b...) - // ensure inputs are set - for i := 0; i < len(b); i++ { - b[i].AssertIsSet() - } - // res = Σ (2**i * b[i]) var res, v frontend.Variable @@ -556,7 +548,7 @@ func (system *compiler) Println(a ...frontend.Variable) { sbb.WriteByte(' ') } if v, ok := arg.(compiled.LinearExpression); ok { - v.AssertIsSet() + assertIsSet(v) sbb.WriteString("%s") // we set limits to the linear expression, so that the log printer diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index fae3e601f8..6b984a2142 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -79,7 +79,7 @@ func (system *compiler) AssertIsLessOrEqual(_v frontend.Variable, bound frontend switch b := bound.(type) { case compiled.LinearExpression: - b.AssertIsSet() + assertIsSet(b) system.mustBeLessOrEqVar(v[0], b) default: system.mustBeLessOrEqCst(v[0], utils.FromInterface(b)) diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go index 035ed9da76..64a785bcc1 100644 --- a/frontend/cs/r1cs/compiler.go +++ b/frontend/cs/r1cs/compiler.go @@ -569,7 +569,7 @@ func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { // Will panic if v.IsConstant() == false func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { if _v, ok := v.(compiled.LinearExpression); ok { - _v.AssertIsSet() + assertIsSet(_v) if len(_v) != 1 { return nil, false @@ -596,7 +596,7 @@ func (system *compiler) toVariable(input interface{}) frontend.Variable { switch t := input.(type) { case compiled.LinearExpression: - t.AssertIsSet() + assertIsSet(t) return t default: n := utils.FromInterface(t) @@ -673,7 +673,7 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte for i, in := range inputs { switch t := in.(type) { case compiled.LinearExpression: - t.AssertIsSet() + assertIsSet(t) tmp := make(compiled.LinearExpression, len(t)) copy(tmp, t) hintInputs[i] = tmp @@ -699,3 +699,18 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte return res, nil } + +// assertIsSet panics if the variable is unset +// this may happen if inside a Define we have +// var a variable +// cs.Mul(a, 1) +// since a was not in the circuit struct it is not a secret variable +func assertIsSet(l compiled.LinearExpression) { + // TODO PlonK scs doesn't have a similar check with compiled.Term == 0 + if len(l) == 0 { + // errNoValue triggered when trying to access a variable that was not allocated + errNoValue := errors.New("can't determine API input value") + panic(errNoValue) + } + +} From 88838e83530ffbc262675ee9a49e1b3338dfc917 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Tue, 1 Mar 2022 11:36:15 -0600 Subject: [PATCH 20/20] refactor: compiler -> r1cs and scs internally --- frontend/cs/r1cs/api.go | 42 +++++++++++++-------------- frontend/cs/r1cs/api_assertions.go | 12 ++++---- frontend/cs/r1cs/compiler.go | 44 ++++++++++++++-------------- frontend/cs/scs/api.go | 42 +++++++++++++-------------- frontend/cs/scs/api_assertions.go | 12 ++++---- frontend/cs/scs/compiler.go | 46 +++++++++++++++--------------- 6 files changed, 99 insertions(+), 99 deletions(-) diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index da921bc727..e437a5bce2 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -35,7 +35,7 @@ import ( // Arithmetic // Add returns res = i1+i2+...in -func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) @@ -54,7 +54,7 @@ func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) f } // Neg returns -i -func (system *compiler) Neg(i frontend.Variable) frontend.Variable { +func (system *r1cs) Neg(i frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i) if n, ok := system.ConstantValue(vars[0]); ok { @@ -66,7 +66,7 @@ func (system *compiler) Neg(i frontend.Variable) frontend.Variable { } // Sub returns res = i1 - i2 -func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) @@ -88,7 +88,7 @@ func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f } // Mul returns res = i1 * i2 * ... in -func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, _ := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) mul := func(v1, v2 compiled.LinearExpression) compiled.LinearExpression { @@ -126,7 +126,7 @@ func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f return res } -func (system *compiler) mulConstant(v1, constant compiled.LinearExpression) compiled.LinearExpression { +func (system *r1cs) mulConstant(v1, constant compiled.LinearExpression) compiled.LinearExpression { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable res := v1.Clone() @@ -153,7 +153,7 @@ func (system *compiler) mulConstant(v1, constant compiled.LinearExpression) comp return res } -func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1, i2) v1 := vars[0] @@ -187,7 +187,7 @@ func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable } // Div returns res = i1 / i2 -func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) Div(i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1, i2) v1 := vars[0] @@ -223,7 +223,7 @@ func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = inverse(v) -func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { +func (system *r1cs) Inverse(i1 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1) if c, ok := system.ConstantValue(vars[0]); ok { @@ -252,7 +252,7 @@ func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *r1cs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -279,7 +279,7 @@ func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari } // toBinary is equivalent to ToBinary, exept the returned bits are NOT boolean constrained. -func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe bool) []frontend.Variable { +func (system *r1cs) toBinary(a compiled.LinearExpression, nbBits int, unsafe bool) []frontend.Variable { if _, ok := system.ConstantValue(a); ok { return system.ToBinary(a, nbBits) @@ -320,7 +320,7 @@ func (system *compiler) toBinary(a compiled.LinearExpression, nbBits int, unsafe } // FromBinary packs b, seen as a fr.Element in little endian -func (system *compiler) FromBinary(_b ...frontend.Variable) frontend.Variable { +func (system *r1cs) FromBinary(_b ...frontend.Variable) frontend.Variable { b, _ := system.toVariables(_b...) // res = Σ (2**i * b[i]) @@ -343,7 +343,7 @@ func (system *compiler) FromBinary(_b ...frontend.Variable) frontend.Variable { } // Xor compute the XOR between two frontend.Variables -func (system *compiler) Xor(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) Xor(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) @@ -365,7 +365,7 @@ func (system *compiler) Xor(_a, _b frontend.Variable) frontend.Variable { } // Or compute the OR between two frontend.Variables -func (system *compiler) Or(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) Or(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) a := vars[0] @@ -385,7 +385,7 @@ func (system *compiler) Or(_a, _b frontend.Variable) frontend.Variable { } // And compute the AND between two frontend.Variables -func (system *compiler) And(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) And(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) a := vars[0] @@ -403,7 +403,7 @@ func (system *compiler) And(_a, _b frontend.Variable) frontend.Variable { // Conditionals // Select if i0 is true, yields i1 else yields i2 -func (system *compiler) Select(i0, i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) Select(i0, i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i0, i1, i2) b := vars[0] @@ -437,7 +437,7 @@ func (system *compiler) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { +func (system *r1cs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -468,7 +468,7 @@ func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten } // IsZero returns 1 if i1 is zero, 0 otherwise -func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { +func (system *r1cs) IsZero(i1 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1) a := vars[0] if c, ok := system.ConstantValue(a); ok { @@ -500,7 +500,7 @@ func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 0 { system.MDebug[len(system.Constraints)-1] = debugID[0] @@ -181,7 +181,7 @@ func (system *compiler) addConstraint(r1c compiled.R1C, debugID ...int) { // Term packs a Variable and a coeff in a Term and returns it. // func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { -func (system *compiler) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { +func (system *r1cs) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { _, vID, vVis := v.Unpack() return compiled.Pack(vID, system.st.CoeffID(coeff), vVis) } @@ -189,7 +189,7 @@ func (system *compiler) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term // MarkBoolean sets (but do not **constraint**!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. -func (system *compiler) MarkBoolean(v frontend.Variable) { +func (system *r1cs) MarkBoolean(v frontend.Variable) { if b, ok := system.ConstantValue(v); ok { if !(b.IsUint64() && b.Uint64() <= 1) { panic("MarkBoolean called a non-boolean constant") @@ -211,7 +211,7 @@ func (system *compiler) MarkBoolean(v frontend.Variable) { // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. -func (system *compiler) IsBoolean(v frontend.Variable) bool { +func (system *r1cs) IsBoolean(v frontend.Variable) bool { if b, ok := system.ConstantValue(v); ok { return b.IsUint64() && b.Uint64() <= 1 } @@ -239,7 +239,7 @@ func (system *compiler) IsBoolean(v frontend.Variable) bool { // // 1. checks that all user inputs are referenced in at least one constraint // 2. checks that all hints are constrained -func (system *compiler) checkVariables() error { +func (system *r1cs) checkVariables() error { // TODO @gbotrel add unit test for that. @@ -343,7 +343,7 @@ func init() { } // Compile constructs a rank-1 constraint sytem -func (cs *compiler) Compile() (frontend.CompiledConstraintSystem, error) { +func (cs *r1cs) Compile() (frontend.CompiledConstraintSystem, error) { // ensure all inputs and hints are constrained if !cs.config.IgnoreUnconstrainedInputs { @@ -464,7 +464,7 @@ HINTLOOP: } } -func (cs *compiler) SetSchema(s *schema.Schema) { +func (cs *r1cs) SetSchema(s *schema.Schema) { if cs.Schema != nil { panic("SetSchema called multiple times") } @@ -567,7 +567,7 @@ func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { // ConstantValue returns the big.Int value of v. // Will panic if v.IsConstant() == false -func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { +func (system *r1cs) ConstantValue(v frontend.Variable) (*big.Int, bool) { if _v, ok := v.(compiled.LinearExpression); ok { assertIsSet(_v) @@ -584,7 +584,7 @@ func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { return &r, true } -func (system *compiler) Backend() backend.ID { +func (system *r1cs) Backend() backend.ID { return backend.GROTH16 } @@ -592,7 +592,7 @@ func (system *compiler) Backend() backend.ID { // // if input is already a compiled.LinearExpression, does nothing // else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable compiled.LinearExpression -func (system *compiler) toVariable(input interface{}) frontend.Variable { +func (system *r1cs) toVariable(input interface{}) frontend.Variable { switch t := input.(type) { case compiled.LinearExpression: @@ -610,7 +610,7 @@ func (system *compiler) toVariable(input interface{}) frontend.Variable { } // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (system *compiler) toVariables(in ...frontend.Variable) ([]compiled.LinearExpression, int) { +func (system *r1cs) toVariables(in ...frontend.Variable) ([]compiled.LinearExpression, int) { r := make([]compiled.LinearExpression, 0, len(in)) s := 0 e := func(i frontend.Variable) { @@ -628,7 +628,7 @@ func (system *compiler) toVariables(in ...frontend.Variable) ([]compiled.LinearE // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to // measure constraints, variables and coefficients creations through AddCounter -func (system *compiler) Tag(name string) frontend.Tag { +func (system *r1cs) Tag(name string) frontend.Tag { _, file, line, _ := runtime.Caller(1) return frontend.Tag{ @@ -639,7 +639,7 @@ func (system *compiler) Tag(name string) frontend.Tag { } // AddCounter measures the number of constraints, variables and coefficients created between two tags -func (system *compiler) AddCounter(from, to frontend.Tag) { +func (system *r1cs) AddCounter(from, to frontend.Tag) { system.Counters = append(system.Counters, compiled.Counter{ From: from.Name, To: to.Name, @@ -662,7 +662,7 @@ func (system *compiler) AddCounter(from, to frontend.Tag) { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (system *r1cs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index 240fbd4f77..5a234cbded 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -32,7 +32,7 @@ import ( ) // Add returns res = i1+i2+...in -func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { zero := big.NewInt(0) vars, k := system.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) @@ -52,7 +52,7 @@ func (system *compiler) Add(i1, i2 frontend.Variable, in ...frontend.Variable) f } // neg returns -in -func (system *compiler) neg(in []frontend.Variable) []frontend.Variable { +func (system *scs) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) @@ -63,13 +63,13 @@ func (system *compiler) neg(in []frontend.Variable) []frontend.Variable { } // Sub returns res = i1 - i2 - ...in -func (system *compiler) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := system.neg(append([]frontend.Variable{i2}, in...)) return system.Add(i1, r[0], r[1:]...) } // Neg returns -i -func (system *compiler) Neg(i1 frontend.Variable) frontend.Variable { +func (system *scs) Neg(i1 frontend.Variable) frontend.Variable { if n, ok := system.ConstantValue(i1); ok { n.Neg(n) // TODO shouldn't that go through variable conversion? @@ -87,7 +87,7 @@ func (system *compiler) Neg(i1 frontend.Variable) frontend.Variable { } // Mul returns res = i1 * i2 * ... in -func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := system.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { @@ -99,7 +99,7 @@ func (system *compiler) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f } // returns t*m -func (system *compiler) mulConstant(t compiled.Term, m *big.Int) compiled.Term { +func (system *scs) mulConstant(t compiled.Term, m *big.Int) compiled.Term { var coef big.Int cid, _, _ := t.Unpack() coef.Set(&system.st.Coeffs[cid]) @@ -110,7 +110,7 @@ func (system *compiler) mulConstant(t compiled.Term, m *big.Int) compiled.Term { } // DivUnchecked returns i1 / i2 . if i1 == i2 == 0, returns 0 -func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { c1, i1Constant := system.ConstantValue(i1) c2, i2Constant := system.ConstantValue(i2) @@ -143,7 +143,7 @@ func (system *compiler) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable } // Div returns i1 / i2 -func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint system.Inverse(i2) @@ -152,7 +152,7 @@ func (system *compiler) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = 1 / i1 -func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { +func (system *scs) Inverse(i1 frontend.Variable) frontend.Variable { if c, ok := system.ConstantValue(i1); ok { c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) return c @@ -173,7 +173,7 @@ func (system *compiler) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -197,7 +197,7 @@ func (system *compiler) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari return system.toBinary(a, nbBits, false) } -func (system *compiler) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { +func (system *scs) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { // allocate the resulting frontend.Variables and bit-constraint them sb := make([]frontend.Variable, nbBits) @@ -235,7 +235,7 @@ func (system *compiler) toBinary(a compiled.Term, nbBits int, unsafe bool) []fro } // FromBinary packs b, seen as a fr.Element in little endian -func (system *compiler) FromBinary(b ...frontend.Variable) frontend.Variable { +func (system *scs) FromBinary(b ...frontend.Variable) frontend.Variable { _b := make([]frontend.Variable, len(b)) var c big.Int c.SetUint64(1) @@ -254,7 +254,7 @@ func (system *compiler) FromBinary(b ...frontend.Variable) frontend.Variable { // Xor returns a ^ b // a and b must be 0 or 1 -func (system *compiler) Xor(a, b frontend.Variable) frontend.Variable { +func (system *scs) Xor(a, b frontend.Variable) frontend.Variable { _a, aConstant := system.ConstantValue(a) _b, bConstant := system.ConstantValue(b) @@ -285,7 +285,7 @@ func (system *compiler) Xor(a, b frontend.Variable) frontend.Variable { // Or returns a | b // a and b must be 0 or 1 -func (system *compiler) Or(a, b frontend.Variable) frontend.Variable { +func (system *scs) Or(a, b frontend.Variable) frontend.Variable { _a, aConstant := system.ConstantValue(a) _b, bConstant := system.ConstantValue(b) @@ -324,7 +324,7 @@ func (system *compiler) Or(a, b frontend.Variable) frontend.Variable { // Or returns a & b // a and b must be 0 or 1 -func (system *compiler) And(a, b frontend.Variable) frontend.Variable { +func (system *scs) And(a, b frontend.Variable) frontend.Variable { system.AssertIsBoolean(a) system.AssertIsBoolean(b) return system.Mul(a, b) @@ -334,7 +334,7 @@ func (system *compiler) And(a, b frontend.Variable) frontend.Variable { // Conditionals // Select if b is true, yields i1 else yields i2 -func (system *compiler) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { _b, bConstant := system.ConstantValue(b) if bConstant { @@ -356,7 +356,7 @@ func (system *compiler) Select(b frontend.Variable, i1, i2 frontend.Variable) fr // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { +func (system *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) // s0, s1 := vars[0], vars[1] @@ -392,7 +392,7 @@ func (system *compiler) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten } // IsZero returns 1 if a is zero, 0 otherwise -func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { +func (system *scs) IsZero(i1 frontend.Variable) frontend.Variable { if a, ok := system.ConstantValue(i1); ok { if !(a.IsUint64() && a.Uint64() == 0) { panic("input should be zero") @@ -418,7 +418,7 @@ func (system *compiler) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 bound -func (system *compiler) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { +func (system *scs) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { case compiled.Term: system.mustBeLessOrEqVar(v.(compiled.Term), b) @@ -98,7 +98,7 @@ func (system *compiler) AssertIsLessOrEqual(v frontend.Variable, bound frontend. } } -func (system *compiler) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { +func (system *scs) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -145,7 +145,7 @@ func (system *compiler) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) } -func (system *compiler) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { +func (system *scs) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { nbBits := system.BitLen() diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go index cad6a0786a..fc7e9882d8 100644 --- a/frontend/cs/scs/compiler.go +++ b/frontend/cs/scs/compiler.go @@ -47,7 +47,7 @@ func NewCompiler(curve ecc.ID, config frontend.CompileConfig) (frontend.Builder, return newCompiler(curve, config), nil } -type compiler struct { +type scs struct { compiled.ConstraintSystem Constraints []compiled.SparseR1C @@ -60,8 +60,8 @@ type compiler struct { // initialCapacity has quite some impact on frontend performance, especially on large circuits size // we may want to add build tags to tune that -func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *compiler { - system := compiler{ +func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *scs { + system := scs{ ConstraintSystem: compiled.ConstraintSystem{ MDebug: make(map[int]int), @@ -83,7 +83,7 @@ func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *compiler { // addPlonkConstraint creates a constraint of the for al+br+clr+k=0 //func (system *SparseR1CS) addPlonkConstraint(l, r, o frontend.Variable, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { -func (system *compiler) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { +func (system *scs) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { if len(debugID) > 0 { system.MDebug[len(system.Constraints)] = debugID[0] @@ -104,14 +104,14 @@ func (system *compiler) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, ci // newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets // the wire's id to the number of wires, and returns it -func (system *compiler) newInternalVariable() compiled.Term { +func (system *scs) newInternalVariable() compiled.Term { idx := system.NbInternalVariables system.NbInternalVariables++ return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) } // AddPublicVariable creates a new Public Variable -func (system *compiler) AddPublicVariable(name string) frontend.Variable { +func (system *scs) AddPublicVariable(name string) frontend.Variable { if system.Schema != nil { panic("do not call AddPublicVariable in circuit.Define()") } @@ -121,7 +121,7 @@ func (system *compiler) AddPublicVariable(name string) frontend.Variable { } // AddSecretVariable creates a new Secret Variable -func (system *compiler) AddSecretVariable(name string) frontend.Variable { +func (system *scs) AddSecretVariable(name string) frontend.Variable { if system.Schema != nil { panic("do not call AddSecretVariable in circuit.Define()") } @@ -134,7 +134,7 @@ func (system *compiler) AddSecretVariable(name string) frontend.Variable { // It factorizes Variable that appears multiple times with != coeff Ids // To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *compiler) reduce(l compiled.LinearExpression) compiled.LinearExpression { +func (system *scs) reduce(l compiled.LinearExpression) compiled.LinearExpression { // ensure our linear expression is sorted, by visibility and by Variable ID sort.Sort(l) @@ -157,7 +157,7 @@ func (system *compiler) reduce(l compiled.LinearExpression) compiled.LinearExpre } // to handle wires that don't exist (=coef 0) in a sparse constraint -func (system *compiler) zero() compiled.Term { +func (system *scs) zero() compiled.Term { var a compiled.Term return a } @@ -165,7 +165,7 @@ func (system *compiler) zero() compiled.Term { // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) // Use with care; variable may not have been **constrained** to be boolean // This returns true if the v is a constant and v == 0 || v == 1. -func (system *compiler) IsBoolean(v frontend.Variable) bool { +func (system *scs) IsBoolean(v frontend.Variable) bool { if b, ok := system.ConstantValue(v); ok { return b.IsUint64() && b.Uint64() <= 1 } @@ -176,7 +176,7 @@ func (system *compiler) IsBoolean(v frontend.Variable) bool { // MarkBoolean sets (but do not constraint!) v to be boolean // This is useful in scenarios where a variable is known to be boolean through a constraint // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. -func (system *compiler) MarkBoolean(v frontend.Variable) { +func (system *scs) MarkBoolean(v frontend.Variable) { if b, ok := system.ConstantValue(v); ok { if !(b.IsUint64() && b.Uint64() <= 1) { panic("MarkBoolean called a non-boolean constant") @@ -189,7 +189,7 @@ func (system *compiler) MarkBoolean(v frontend.Variable) { // // 1. checks that all user inputs are referenced in at least one constraint // 2. checks that all hints are constrained -func (system *compiler) checkVariables() error { +func (system *scs) checkVariables() error { // TODO @gbotrel add unit test for that. @@ -294,7 +294,7 @@ func init() { tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() } -func (cs *compiler) Compile() (frontend.CompiledConstraintSystem, error) { +func (cs *scs) Compile() (frontend.CompiledConstraintSystem, error) { // ensure all inputs and hints are constrained if !cs.config.IgnoreUnconstrainedInputs { @@ -410,7 +410,7 @@ HINTLOOP: } -func (cs *compiler) SetSchema(s *schema.Schema) { +func (cs *scs) SetSchema(s *schema.Schema) { if cs.Schema != nil { panic("SetSchema called multiple times") } @@ -515,7 +515,7 @@ func (b *levelBuilder) processTerm(t compiled.Term, cID int) { // ConstantValue returns the big.Int value of v. It // panics if v.IsConstant() == false -func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { +func (system *scs) ConstantValue(v frontend.Variable) (*big.Int, bool) { switch t := v.(type) { case compiled.Term: return nil, false @@ -525,13 +525,13 @@ func (system *compiler) ConstantValue(v frontend.Variable) (*big.Int, bool) { } } -func (system *compiler) Backend() backend.ID { +func (system *scs) Backend() backend.ID { return backend.PLONK } // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to // measure constraints, variables and coefficients creations through AddCounter -func (system *compiler) Tag(name string) frontend.Tag { +func (system *scs) Tag(name string) frontend.Tag { _, file, line, _ := runtime.Caller(1) return frontend.Tag{ @@ -544,7 +544,7 @@ func (system *compiler) Tag(name string) frontend.Tag { // AddCounter measures the number of constraints, variables and coefficients created between two tags // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time -func (system *compiler) AddCounter(from, to frontend.Tag) { +func (system *scs) AddCounter(from, to frontend.Tag) { system.Counters = append(system.Counters, compiled.Counter{ From: from.Name, To: to.Name, @@ -567,7 +567,7 @@ func (system *compiler) AddCounter(from, to frontend.Tag) { // // No new constraints are added to the newly created wire and must be added // manually in the circuit. Failing to do so leads to solver failure. -func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (system *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") @@ -604,7 +604,7 @@ func (system *compiler) NewHint(f hint.Function, nbOutputs int, inputs ...fronte } // returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt -func (system *compiler) filterConstantSum(in []frontend.Variable) (compiled.LinearExpression, big.Int) { +func (system *scs) filterConstantSum(in []frontend.Variable) (compiled.LinearExpression, big.Int) { res := make(compiled.LinearExpression, 0, len(in)) var b big.Int for i := 0; i < len(in); i++ { @@ -620,7 +620,7 @@ func (system *compiler) filterConstantSum(in []frontend.Variable) (compiled.Line } // returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt -func (system *compiler) filterConstantProd(in []frontend.Variable) (compiled.LinearExpression, big.Int) { +func (system *scs) filterConstantProd(in []frontend.Variable) (compiled.LinearExpression, big.Int) { res := make(compiled.LinearExpression, 0, len(in)) var b big.Int b.SetInt64(1) @@ -636,7 +636,7 @@ func (system *compiler) filterConstantProd(in []frontend.Variable) (compiled.Lin return res, b } -func (system *compiler) splitSum(acc compiled.Term, r compiled.LinearExpression) compiled.Term { +func (system *scs) splitSum(acc compiled.Term, r compiled.LinearExpression) compiled.Term { // floor case if len(r) == 0 { @@ -650,7 +650,7 @@ func (system *compiler) splitSum(acc compiled.Term, r compiled.LinearExpression) return system.splitSum(o, r[1:]) } -func (system *compiler) splitProd(acc compiled.Term, r compiled.LinearExpression) compiled.Term { +func (system *scs) splitProd(acc compiled.Term, r compiled.LinearExpression) compiled.Term { // floor case if len(r) == 0 {