diff --git a/README.md b/README.md index 0a21565df2..d4fcf03e5b 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ func (circuit *CubicCircuit) Define(api frontend.API) error { // compiles our circuit into a R1CS var circuit CubicCircuit -ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit) +ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit) // groth16 zkSNARK: Setup pk, vk, err := groth16.Setup(ccs) diff --git a/backend/groth16/groth16.go b/backend/groth16/groth16.go index c7f79554a5..1a286f5a3c 100644 --- a/backend/groth16/groth16.go +++ b/backend/groth16/groth16.go @@ -27,7 +27,6 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs" backend_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/cs" @@ -52,10 +51,6 @@ import ( groth16_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/groth16" ) -func init() { - frontend.RegisterDefaultBuilder(backend.GROTH16, r1cs.NewBuilder) -} - // TODO @gbotrel document hint functions here and in assert type groth16Object interface { diff --git a/backend/hint/builtin.go b/backend/hint/builtin.go index 9fcf91771f..7840c74a2d 100644 --- a/backend/hint/builtin.go +++ b/backend/hint/builtin.go @@ -11,13 +11,17 @@ var initBuiltinOnce sync.Once func init() { initBuiltinOnce.Do(func() { - IsZero = NewStaticHint(builtinIsZero, 1, 1) + IsZero = NewStaticHint(builtinIsZero) Register(IsZero) - IthBit = NewStaticHint(builtinIthBit, 2, 1) + IthBit = NewStaticHint(builtinIthBit) Register(IthBit) + NBits = NewStaticHint(builtinNBits) + Register(NBits) }) } +// TODO FIXME these may be redefined easily by an external package + // The package provides the following built-in hint functions. All built-in hint // functions are registered in the registry. var ( @@ -30,6 +34,9 @@ var ( // integer inputs i and n, takes the little-endian bit representation of n and // returns its i-th bit. IthBit Function + + // NBits returns the n first bits of the input. Expects one argument: n. + NBits Function ) func builtinIsZero(curveID ecc.ID, inputs []*big.Int, results []*big.Int) error { @@ -63,3 +70,11 @@ func builtinIthBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64())))) return nil } + +func builtinNBits(_ ecc.ID, inputs []*big.Int, results []*big.Int) error { + n := inputs[0] + for i := 0; i < len(results); i++ { + results[i].SetUint64(uint64(n.Bit(i))) + } + return nil +} diff --git a/backend/hint/hint.go b/backend/hint/hint.go index a68f768b8d..a21bff6134 100644 --- a/backend/hint/hint.go +++ b/backend/hint/hint.go @@ -81,7 +81,7 @@ type ID uint32 // StaticFunction is a function which takes a constant number of inputs and // returns a constant number of outputs. Use NewStaticHint() to construct an // instance compatible with Function interface. -type StaticFunction func(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error +type StaticFunction func(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // Function defines an annotated hint function. To initialize a hint function // with static number of inputs and outputs, use NewStaticHint(). @@ -91,21 +91,18 @@ type Function interface { UUID() ID // Call is invoked by the framework to obtain the result from inputs. - // The length of res is NbOutputs() and every element is - // already initialized (but not necessarily to zero as the elements may be - // obtained from cache). A returned non-nil error will be propagated. - Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error - - // NbOutputs returns the total number of outputs by the function when - // invoked on the curveID with nInputs number of inputs. The number of - // outputs must be at least one and the framework errors otherwise. - NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int) + // Elements in outputs are not guaranteed to be initialized to 0 + Call(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error // String returns a human-readable description of the function used in logs // and debug messages. String() string } +func NewStaticHint(fn StaticFunction) Function { + return fn +} + // UUID is a reference function for computing the hint ID based on a function // and additional context values ctx. A change in any of the inputs modifies the // returned value and thus this function can be used to compute the hint ID for @@ -127,46 +124,16 @@ func UUID(fn StaticFunction, ctx ...uint64) ID { return ID(hf.Sum32()) } -// staticArgumentsFunction defines a function where the number of inputs and -// outputs is constant. -type staticArgumentsFunction struct { - fn StaticFunction - nIn int - nOut int -} - -// NewStaticHint returns an Function where the number of inputs and outputs is -// constant. UUID is computed by combining fn, nIn and nOut and thus it is legal -// to defined multiple AnnotatedFunctions on the same fn with different nIn and -// nOut. -func NewStaticHint(fn StaticFunction, nIn, nOut int) Function { - return &staticArgumentsFunction{ - fn: fn, - nIn: nIn, - nOut: nOut, - } -} - -func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { - if len(inputs) != h.nIn { - return fmt.Errorf("input has %d elements, expected %d", len(inputs), h.nIn) - } - if len(res) != h.nOut { - return fmt.Errorf("result has %d elements, expected %d", len(res), h.nOut) - } - return h.fn(curveID, inputs, res) -} - -func (h *staticArgumentsFunction) NbOutputs(_ ecc.ID, _ int) int { - return h.nOut +func (h StaticFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error { + return h(curveID, inputs, res) } -func (h *staticArgumentsFunction) UUID() ID { - return UUID(h.fn, uint64(h.nIn), uint64(h.nOut)) +func (h StaticFunction) UUID() ID { + return UUID(h) } -func (h *staticArgumentsFunction) String() string { - fnptr := reflect.ValueOf(h.fn).Pointer() +func (h StaticFunction) String() string { + fnptr := reflect.ValueOf(h).Pointer() name := runtime.FuncForPC(fnptr).Name() - return fmt.Sprintf("%s([%d]*big.Int, [%d]*big.Int) at (%x)", name, h.nIn, h.nOut, fnptr) + return fmt.Sprintf("%s([?]*big.Int, [?]*big.Int) at (%x)", name, fnptr) } diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index 158792bcb3..4676fe77b2 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -26,7 +26,6 @@ import ( "github.com/consensys/gnark-crypto/kzg" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" "github.com/consensys/gnark/backend/witness" cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" @@ -58,10 +57,6 @@ import ( kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" ) -func init() { - frontend.RegisterDefaultBuilder(backend.PLONK, plonk.NewBuilder) -} - // Proof represents a Plonk proof generated by plonk.Prove // // it's underlying implementation is curve specific (see gnark/internal/backend) diff --git a/circuitstats_test.go b/circuitstats_test.go index 04e3b22964..c5bd5b73a7 100644 --- a/circuitstats_test.go +++ b/circuitstats_test.go @@ -9,6 +9,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/circuits" "github.com/consensys/gnark/test" ) @@ -26,19 +28,30 @@ func TestCircuitStatistics(t *testing.T) { for _, curve := range ecc.Implemented() { for _, b := range backend.Implemented() { curve := curve - b := b + backendID := b name := k // copy the circuit now in case assert calls t.Parallel() tData := circuits.Circuits[k] assert.Run(func(assert *test.Assert) { - ccs, err := frontend.Compile(curve, b, tData.Circuit) + var newCompiler frontend.NewCompiler + + switch backendID { + case backend.GROTH16: + newCompiler = r1cs.NewCompiler + case backend.PLONK: + newCompiler = scs.NewCompiler + default: + panic("not implemented") + } + + ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit) assert.NoError(err) // ensure we didn't introduce regressions that make circuits less efficient nbConstraints := ccs.GetNbConstraints() internal, secret, public := ccs.GetNbVariables() - checkStats(assert, name, nbConstraints, internal, secret, public, curve, b) - }, name, curve.String(), b.String()) + checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID) + }, name, curve.String(), backendID.String()) } } diff --git a/debug_test.go b/debug_test.go index 51910507d2..126cc7abfa 100644 --- a/debug_test.go +++ b/debug_test.go @@ -10,6 +10,8 @@ import ( "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" ) @@ -43,11 +45,11 @@ func TestPrintln(t *testing.T) { witness.B = 11 var expected bytes.Buffer - expected.WriteString("debug_test.go:25 13 is the addition\n") - expected.WriteString("debug_test.go:27 26 42\n") - expected.WriteString("debug_test.go:29 bits 1\n") - expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:34 m .*\n") + expected.WriteString("debug_test.go:27 13 is the addition\n") + expected.WriteString("debug_test.go:29 26 42\n") + expected.WriteString("debug_test.go:31 bits 1\n") + expected.WriteString("debug_test.go:32 circuit {A: 2, B: 11}\n") + expected.WriteString("debug_test.go:36 m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) @@ -172,7 +174,7 @@ func TestTraceNotBoolean(t *testing.T) { } func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, circuit) if err != nil { return "", err } @@ -196,7 +198,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { } func getGroth16Trace(circuit, w frontend.Circuit) (string, error) { - ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, circuit) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, circuit) if err != nil { return "", err } diff --git a/examples/plonk/main.go b/examples/plonk/main.go index 80a6587822..31b809eff4 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -19,8 +19,8 @@ import ( "log" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/internal/backend/bn254/cs" "github.com/consensys/gnark/test" @@ -73,7 +73,7 @@ func main() { var circuit Circuit // // building the circuit... - ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, &circuit) + ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, &circuit) if err != nil { fmt.Println("circuit compilation error") } diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 83f554c2b9..8d15a115ae 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -88,7 +88,7 @@ type TransferConstraints struct { func (circuit *Circuit) postInit(api frontend.API) error { // edward curve params - params, err := twistededwards.NewEdCurve(api.Curve()) + params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/examples/serialization/main.go b/examples/serialization/main.go index dfbcc77c85..ca0d7b11cd 100644 --- a/examples/serialization/main.go +++ b/examples/serialization/main.go @@ -6,9 +6,9 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/examples/cubic" ) @@ -17,7 +17,7 @@ func main() { var circuit cubic.Circuit // compile a circuit - _r1cs, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit) + _r1cs, _ := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit) // R1CS implements io.WriterTo and io.ReaderFrom var buf bytes.Buffer diff --git a/frontend/api.go b/frontend/api.go index 3e7fde87c0..2c1961eb15 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -52,6 +52,7 @@ type API interface { // --------------------------------------------------------------------------------------------- // Bit operations + // TODO @gbotrel move bit operations in std/math/bits // ToBinary unpacks a Variable in binary, // n is the number of bits to select (starting from lsb) @@ -111,39 +112,32 @@ type API interface { // whose value will be resolved at runtime when computed by the solver Println(a ...Variable) - // NewHint initializes internal variables whose value will be evaluated - // using the provided hint function at run time from the inputs. Inputs must - // be either variables or convertible to *big.Int. The function returns an - // error if the number of inputs is not compatible with f. - // - // The hint function is provided at the proof creation time and is not - // embedded into the circuit. From the backend point of view, the variable - // returned by the hint function is equivalent to the user-supplied witness, - // but its actual value is assigned by the solver, not the caller. - // - // No new constraints are added to the newly created wire and must be added - // manually in the circuit. Failing to do so leads to solver failure. - NewHint(f hint.Function, inputs ...Variable) ([]Variable, error) + // Compiler returns the compiler object for advanced circuit development + Compiler() Compiler + + // Deprecated APIs - // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to - // measure constraints, variables and coefficients creations through AddCounter + // NewHint is a shorcut to api.Compiler().NewHint() + // Deprecated: use api.Compiler().NewHint() instead + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + + // Tag is a shorcut to api.Compiler().Tag() + // Deprecated: use api.Compiler().Tag() instead Tag(name string) Tag - // AddCounter measures the number of constraints, variables and coefficients created between two tags - // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions - // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time + // AddCounter is a shorcut to api.Compiler().AddCounter() + // Deprecated: use api.Compiler().AddCounter() instead AddCounter(from, to Tag) - // IsConstant returns true if v is a constant known at compile time - IsConstant(v Variable) bool - - // ConstantValue returns the big.Int value of v. It - // panics if v.IsConstant() == false - ConstantValue(v Variable) *big.Int + // ConstantValue is a shorcut to api.Compiler().ConstantValue() + // Deprecated: use api.Compiler().ConstantValue() instead + ConstantValue(v Variable) (*big.Int, bool) - // CurveID returns the ecc.ID injected by the compiler + // Curve is a shorcut to api.Compiler().Curve() + // Deprecated: use api.Compiler().Curve() instead Curve() ecc.ID - // Backend returns the backend.ID injected by the compiler + // Backend is a shorcut to api.Compiler().Backend() + // Deprecated: use api.Compiler().Backend() instead Backend() backend.ID } diff --git a/frontend/ccs.go b/frontend/ccs.go index d3659cc5c5..c7d91ca62a 100644 --- a/frontend/ccs.go +++ b/frontend/ccs.go @@ -20,8 +20,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) // CompiledConstraintSystem interface that a compiled (=typed, and correctly routed) diff --git a/frontend/compile.go b/frontend/compile.go index 4565063d0f..f934db4c98 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -6,29 +6,10 @@ import ( "reflect" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" ) -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A Variable }{}).FieldByName("A").Type() -} - -// Builder represents a constraint system builder -type Builder interface { - API - CheckVariables() error - NewPublicVariable(name string) Variable - NewSecretVariable(name string) Variable - Compile() (CompiledConstraintSystem, error) - SetSchema(*schema.Schema) -} - -type NewBuilder func(ecc.ID) (Builder, error) - // Compile will generate a ConstraintSystem from the given circuit // // 1. it will first allocate the user inputs (see type Tag for more info) @@ -47,49 +28,33 @@ type NewBuilder func(ecc.ID) (Builder, error) // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. -func Compile(curveID ecc.ID, zkpID backend.ID, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { - // setup option - opt := compileConfig{} +func Compile(curveID ecc.ID, newCompiler NewCompiler, circuit Circuit, opts ...CompileOption) (CompiledConstraintSystem, error) { + // parse options + opt := CompileConfig{} for _, o := range opts { if err := o(&opt); err != nil { return nil, fmt.Errorf("apply option: %w", err) } } - newBuilder := opt.newBuilder - if newBuilder == nil { - var ok bool - backendsM.RLock() - newBuilder, ok = backends[zkpID] - backendsM.RUnlock() - if !ok { - return nil, fmt.Errorf("no default constraint builder registered nor set as option") - } - } - builder, err := newBuilder(curveID) + + // instantiate new compiler + compiler, err := newCompiler(curveID, opt) if err != nil { - return nil, fmt.Errorf("new builder: %w", err) + return nil, fmt.Errorf("new compiler: %w", err) } - if err = bootstrap(builder, circuit); err != nil { - return nil, fmt.Errorf("bootstrap: %w", err) + // parse the circuit builds a schema of the circuit + // and call circuit.Define() method to initialize a list of constraints in the compiler + if err = parseCircuit(compiler, circuit); err != nil { + return nil, fmt.Errorf("parse circuit: %w", err) } - // ensure all inputs and hints are constrained - if !opt.ignoreUnconstrainedInputs { - if err := builder.CheckVariables(); err != nil { - return nil, err - } - } - - ccs, err := builder.Compile() - if err != nil { - return nil, fmt.Errorf("compile system: %w", err) - } - return ccs, nil + // compile the circuit into its final form + return compiler.Compile() } -func bootstrap(builder Builder, circuit Circuit) (err error) { +func parseCircuit(builder Builder, circuit Circuit) (err error) { // ensure circuit.Define has pointer receiver if reflect.ValueOf(circuit).Kind() != reflect.Ptr { return errors.New("frontend.Circuit methods must be defined on pointer receiver") @@ -101,9 +66,9 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { if tInput.CanSet() { switch visibility { case schema.Secret: - tInput.Set(reflect.ValueOf(builder.NewSecretVariable(name))) + tInput.Set(reflect.ValueOf(builder.AddSecretVariable(name))) case schema.Public: - tInput.Set(reflect.ValueOf(builder.NewPublicVariable(name))) + tInput.Set(reflect.ValueOf(builder.AddPublicVariable(name))) case schema.Unset: return errors.New("can't set val " + name + " visibility is unset") } @@ -138,20 +103,19 @@ func bootstrap(builder Builder, circuit Circuit) (err error) { // CompileOption defines option for altering the behaviour of the Compile // method. See the descriptions of the functions returning instances of this // type for available options. -type CompileOption func(opt *compileConfig) error +type CompileOption func(opt *CompileConfig) error -type compileConfig struct { - capacity int - ignoreUnconstrainedInputs bool - newBuilder NewBuilder +type CompileConfig struct { + Capacity int + IgnoreUnconstrainedInputs bool } // WithCapacity is a compile option that specifies the estimated capacity needed // for internal variables and constraints. If not set, then the initial capacity // is 0 and is dynamically allocated as needed. func WithCapacity(capacity int) CompileOption { - return func(opt *compileConfig) error { - opt.capacity = capacity + return func(opt *CompileConfig) error { + opt.Capacity = capacity return nil } } @@ -164,19 +128,14 @@ func WithCapacity(capacity int) CompileOption { // production settings as it means that there is a potential error in the // circuit definition or that it is possible to optimize witness size. func IgnoreUnconstrainedInputs() CompileOption { - return func(opt *compileConfig) error { - opt.ignoreUnconstrainedInputs = true + return func(opt *CompileConfig) error { + opt.IgnoreUnconstrainedInputs = true return nil } } -// WithBuilder is a compile option which enables the compiler to build the -// constraint system with a user-defined builder. -// -// /!\ This is highly experimental and may change in upcoming releases /!\ -func WithBuilder(builder NewBuilder) CompileOption { - return func(opt *compileConfig) error { - opt.newBuilder = builder - return nil - } +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A Variable }{}).FieldByName("A").Type() } diff --git a/frontend/compiled/cs.go b/frontend/compiled/cs.go new file mode 100644 index 0000000000..ea6a87d2fe --- /dev/null +++ b/frontend/compiled/cs.go @@ -0,0 +1,122 @@ +package compiled + +import ( + "fmt" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/internal/utils" +) + +// ConstraintSystem contains common element between R1CS and ConstraintSystem +type ConstraintSystem struct { + + // schema of the circuit + Schema *schema.Schema + + // number of wires + NbInternalVariables int + NbPublicVariables int + NbSecretVariables int + + // input wires names + Public, Secret []string + + // logs (added with cs.Println, resolved when solver sets a value to a wire) + Logs []LogEntry + + // debug info contains stack trace (including line number) of a call to a cs.API that + // results in an unsolved constraint + DebugInfo []LogEntry + + // maps constraint id to debugInfo id + // several constraints may point to the same debug info + MDebug map[int]int + + Counters []Counter // TODO @gbotrel no point in serializing these + + // maps wire id to hint + // a wire may point to at most one hint + MHints map[int]*Hint + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependncies for constraints in a level l are solved + // in previous levels + Levels [][]int + + CurveID ecc.ID +} + +// GetNbVariables return number of internal, secret and public variables +func (cs *ConstraintSystem) GetNbVariables() (internal, secret, public int) { + return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables +} + +// GetCounters return the collected constraint counters, if any +func (cs *ConstraintSystem) GetCounters() []Counter { return cs.Counters } + +func (cs *ConstraintSystem) GetSchema() *schema.Schema { return cs.Schema } + +// Counter contains measurements of useful statistics between two Tag +type Counter struct { + From, To string + NbVariables int + NbConstraints int + CurveID ecc.ID + BackendID backend.ID +} + +func (c Counter) String() string { + return fmt.Sprintf("%s[%s] %s - %s: %d variables, %d constraints", c.BackendID, c.CurveID, c.From, c.To, c.NbVariables, c.NbConstraints) +} + +func (cs *ConstraintSystem) Curve() ecc.ID { + return cs.CurveID +} + +func (cs *ConstraintSystem) AddDebugInfo(errName string, i ...interface{}) int { + + var l LogEntry + + const minLogSize = 500 + var sbb strings.Builder + sbb.Grow(minLogSize) + sbb.WriteString("[") + sbb.WriteString(errName) + sbb.WriteString("] ") + + for _, _i := range i { + switch v := _i.(type) { + case LinearExpression: + if len(v) > 1 { + sbb.WriteString("(") + } + l.WriteVariable(v, &sbb) + if len(v) > 1 { + sbb.WriteString(")") + } + case string: + sbb.WriteString(v) + case Term: + l.WriteTerm(v, &sbb) + default: + _v := utils.FromInterface(v) + sbb.WriteString(_v.String()) + } + } + sbb.WriteByte('\n') + debug.WriteStack(&sbb) + l.Format = sbb.String() + + cs.DebugInfo = append(cs.DebugInfo, l) + + return len(cs.DebugInfo) - 1 +} + +// bitLen returns the number of bits needed to represent a fr.Element +func (cs *ConstraintSystem) BitLen() int { + return cs.CurveID.Info().Fr.Bits +} diff --git a/internal/backend/compiled/cs.go b/frontend/compiled/hint.go similarity index 55% rename from internal/backend/compiled/cs.go rename to frontend/compiled/hint.go index b91dbba9f4..90ef4c8caa 100644 --- a/internal/backend/compiled/cs.go +++ b/frontend/compiled/hint.go @@ -2,49 +2,13 @@ package compiled import ( "fmt" - "io" "math/big" "reflect" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/frontend/schema" "github.com/fxamacker/cbor/v2" ) -// CS contains common element between R1CS and CS -type CS struct { - - // number of wires - NbInternalVariables int - NbPublicVariables int - NbSecretVariables int - - // logs (added with cs.Println, resolved when solver sets a value to a wire) - Logs []LogEntry - // debug info contains stack trace (including line number) of a call to a cs.API that - // results in an unsolved constraint - DebugInfo []LogEntry - - // maps constraint id to debugInfo id - // several constraints may point to the same debug info - MDebug map[int]int - - Counters []Counter // TODO @gbotrel no point in serializing these - // maps wire id to hint - - // a wire may point to at most one hint - MHints map[int]*Hint - - Schema *schema.Schema - - // each level contains independent constraints and can be parallelized - // it is guaranteed that all dependncies for constraints in a level l are solved - // in previous levels - Levels [][]int -} - // Hint represents a solver hint // it enables the solver to compute a Wire with a function provided at solving time // using pre-defined inputs @@ -60,9 +24,6 @@ func (h Hint) inputsCBORTags() (cbor.TagSet, error) { if err := tags.Add(defTagOpts, reflect.TypeOf(LinearExpression{}), 25443); err != nil { return nil, fmt.Errorf("new LE tag: %w", err) } - if err := tags.Add(defTagOpts, reflect.TypeOf(Variable{}), 25444); err != nil { - return nil, fmt.Errorf("new variable tag: %w", err) - } if err := tags.Add(defTagOpts, reflect.TypeOf(Term(0)), 25445); err != nil { return nil, fmt.Errorf("new term tag: %w", err) } @@ -136,12 +97,6 @@ func (h *Hint) UnmarshalCBOR(b []byte) error { return fmt.Errorf("unmarshal linear expression: %w", err) } inputs[i] = LinearExpression(v) - case 25444: - var v Variable - if err := dec.Unmarshal(vin.Content, &v); err != nil { - return fmt.Errorf("unmarshal variable: %w", err) - } - inputs[i] = v case 25445: var v uint64 if err := dec.Unmarshal(vin.Content, &v); err != nil { @@ -173,46 +128,3 @@ func (h *Hint) UnmarshalCBOR(b []byte) error { h.Wires = v.Wires return nil } - -// GetNbVariables return number of internal, secret and public variables -func (cs *CS) GetNbVariables() (internal, secret, public int) { - return cs.NbInternalVariables, cs.NbSecretVariables, cs.NbPublicVariables -} - -// FrSize panics -func (cs *CS) FrSize() int { panic("not implemented") } - -// GetNbCoefficients panics -func (cs *CS) GetNbCoefficients() int { panic("not implemented") } - -// CurveID returns ecc.UNKNOWN -func (cs *CS) CurveID() ecc.ID { return ecc.UNKNOWN } - -// WriteTo panics -func (cs *CS) WriteTo(w io.Writer) (n int64, err error) { panic("not implemented") } - -// ReadFrom panics -func (cs *CS) ReadFrom(r io.Reader) (n int64, err error) { panic("not implemented") } - -// ToHTML panics -func (cs *CS) ToHTML(w io.Writer) error { panic("not implemtened") } - -// GetCounters return the collected constraint counters, if any -func (cs *CS) GetCounters() []Counter { return cs.Counters } - -func (cs *CS) GetSchema() *schema.Schema { return cs.Schema } - -func (cs *CS) GetConstraints() [][]string { panic("not implemented") } - -// Counter contains measurements of useful statistics between two Tag -type Counter struct { - From, To string - NbVariables int - NbConstraints int - CurveID ecc.ID - BackendID backend.ID -} - -func (c Counter) String() string { - return fmt.Sprintf("%s[%s] %s - %s: %d variables, %d constraints", c.BackendID, c.CurveID, c.From, c.To, c.NbVariables, c.NbConstraints) -} diff --git a/frontend/compiled/linear_expression.go b/frontend/compiled/linear_expression.go new file mode 100644 index 0000000000..46baa9c868 --- /dev/null +++ b/frontend/compiled/linear_expression.go @@ -0,0 +1,87 @@ +// Copyright 2021 ConsenSys AG +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compiled + +import ( + "math/big" + "strings" +) + +// A linear expression is a linear combination of Term +type LinearExpression []Term + +// Clone returns a copy of the underlying slice +func (l LinearExpression) Clone() LinearExpression { + res := make(LinearExpression, len(l)) + copy(res, l) + return res +} + +func (l LinearExpression) string(sbb *strings.Builder, coeffs []big.Int) { + for i := 0; i < len(l); i++ { + l[i].string(sbb, coeffs) + if i+1 < len(l) { + sbb.WriteString(" + ") + } + } +} + +// Len return the lenght of the Variable (implements Sort interface) +func (l LinearExpression) Len() int { + return len(l) +} + +// Equals returns true if both SORTED expressions are the same +// +// pre conditions: l and o are sorted +func (l LinearExpression) Equal(o LinearExpression) bool { + if len(l) != len(o) { + return false + } + if (l == nil) != (o == nil) { + return false + } + for i := 0; i < len(l); i++ { + if l[i] != o[i] { + return false + } + } + return true +} + +// Swap swaps terms in the Variable (implements Sort interface) +func (l LinearExpression) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) +func (l LinearExpression) Less(i, j int) bool { + _, iID, iVis := l[i].Unpack() + _, jID, jVis := l[j].Unpack() + if iVis == jVis { + return iID < jID + } + return iVis > jVis +} + +// HashCode returns a fast-to-compute but NOT collision resistant hash code identifier for the linear +// expression +func (l LinearExpression) HashCode() uint64 { + h := uint64(17) + for _, val := range l { + h = h*23 + uint64(val) + } + return h +} diff --git a/internal/backend/compiled/log.go b/frontend/compiled/log.go similarity index 88% rename from internal/backend/compiled/log.go rename to frontend/compiled/log.go index a578b80b92..b25676e006 100644 --- a/internal/backend/compiled/log.go +++ b/frontend/compiled/log.go @@ -28,14 +28,14 @@ type LogEntry struct { ToResolve []Term } -func (l *LogEntry) WriteVariable(le Variable, sbb *strings.Builder) { - sbb.Grow(len(le.LinExp) * len(" + (xx + xxxxxxxxxxxx")) +func (l *LogEntry) WriteVariable(le LinearExpression, sbb *strings.Builder) { + sbb.Grow(len(le) * len(" + (xx + xxxxxxxxxxxx")) - for i := 0; i < len(le.LinExp); i++ { + for i := 0; i < len(le); i++ { if i > 0 { sbb.WriteString(" + ") } - l.WriteTerm(le.LinExp[i], sbb) + l.WriteTerm(le[i], sbb) } } diff --git a/internal/backend/compiled/r1c.go b/frontend/compiled/r1cs.go similarity index 79% rename from internal/backend/compiled/r1c.go rename to frontend/compiled/r1cs.go index a2052de1d0..2bf3a0d8a6 100644 --- a/internal/backend/compiled/r1c.go +++ b/frontend/compiled/r1cs.go @@ -19,9 +19,20 @@ import ( "strings" ) +// R1CS decsribes a set of R1C constraint +type R1CS struct { + ConstraintSystem + Constraints []R1C +} + +// GetNbConstraints returns the number of constraints +func (r1cs *R1CS) GetNbConstraints() int { + return len(r1cs.Constraints) +} + // R1C used to compute the wires type R1C struct { - L, R, O Variable + L, R, O LinearExpression } func (r1c *R1C) String(coeffs []big.Int) string { diff --git a/internal/backend/compiled/r1c_sparse.go b/frontend/compiled/r1cs_sparse.go similarity index 84% rename from internal/backend/compiled/r1c_sparse.go rename to frontend/compiled/r1cs_sparse.go index c7adee2a06..20f554d6dd 100644 --- a/internal/backend/compiled/r1c_sparse.go +++ b/frontend/compiled/r1cs_sparse.go @@ -19,6 +19,17 @@ import ( "strings" ) +// R1CS decsribes a set of SparseR1C constraint +type SparseR1CS struct { + ConstraintSystem + Constraints []SparseR1C +} + +// GetNbConstraints returns the number of constraints +func (cs *SparseR1CS) GetNbConstraints() int { + return len(cs.Constraints) +} + // SparseR1C used to compute the wires // L+R+M[0]M[1]+O+k=0 // if a Term is zero, it means the field doesn't exist (ex M=[0,0] means there is no multiplicative term) diff --git a/internal/backend/compiled/term.go b/frontend/compiled/term.go similarity index 94% rename from internal/backend/compiled/term.go rename to frontend/compiled/term.go index 57966339b9..08148636e2 100644 --- a/internal/backend/compiled/term.go +++ b/frontend/compiled/term.go @@ -22,11 +22,9 @@ import ( "github.com/consensys/gnark/frontend/schema" ) -// Term lightweight version of a term, no pointers -// first 4 bits are reserved -// next 30 bits represented the coefficient idx (in r1cs.Coefficients) by which the wire is multiplied -// next 30 bits represent the constraint used to compute the wire -// if we support more than 1 billion constraints, this breaks (not so soon.) +// Term lightweight version of a term, no pointers. A term packs wireID, coeffID, visibility and +// some metadata associated with the term, in a uint64. +// note: if we support more than 1 billion constraints, this breaks (not so soon.) type Term uint64 // ids of the coefficients with simple values in any cs.coeffs slice. diff --git a/frontend/compiler.go b/frontend/compiler.go new file mode 100644 index 0000000000..866e9bc417 --- /dev/null +++ b/frontend/compiler.go @@ -0,0 +1,81 @@ +package frontend + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/schema" +) + +type NewCompiler func(ecc.ID, CompileConfig) (Builder, error) + +// Compiler represents a constraint system compiler +type Compiler interface { + // MarkBoolean sets (but do not constraint!) v to be boolean + // This is useful in scenarios where a variable is known to be boolean through a constraint + // that is not api.AssertIsBoolean. If v is a constant, this is a no-op. + MarkBoolean(v Variable) + + // IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) + // Use with care; variable may not have been **constrained** to be boolean + // This returns true if the v is a constant and v == 0 || v == 1. + IsBoolean(v Variable) bool + + // NewHint initializes internal variables whose value will be evaluated + // using the provided hint function at run time from the inputs. Inputs must + // be either variables or convertible to *big.Int. The function returns an + // error if the number of inputs is not compatible with f. + // + // The hint function is provided at the proof creation time and is not + // embedded into the circuit. From the backend point of view, the variable + // returned by the hint function is equivalent to the user-supplied witness, + // but its actual value is assigned by the solver, not the caller. + // + // No new constraints are added to the newly created wire and must be added + // manually in the circuit. Failing to do so leads to solver failure. + // + // If nbOutputs is specified, it must be >= 1 and <= f.NbOutputs + NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error) + + // Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to + // measure constraints, variables and coefficients creations through AddCounter + Tag(name string) Tag + + // AddCounter measures the number of constraints, variables and coefficients created between two tags + // note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions + // are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time + AddCounter(from, to Tag) + + // ConstantValue returns the big.Int value of v and true if op is a success. + // nil and false if failure. This API returns a boolean to allow for future refactoring + // replacing *big.Int with fr.Element + ConstantValue(v Variable) (*big.Int, bool) + + // CurveID returns the ecc.ID injected by the compiler + Curve() ecc.ID + + // Backend returns the backend.ID injected by the compiler + Backend() backend.ID +} + +// Builder represents a constraint system builder +type Builder interface { + API + Compiler + + // Compile is called after circuit.Define() to produce a final IR (CompiledConstraintSystem) + Compile() (CompiledConstraintSystem, error) + + // SetSchema is used internally by frontend.Compile to set the circuit schema + SetSchema(*schema.Schema) + + // AddPublicVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddPublicVariable(name string) Variable + + // AddSecretVariable is called by the compiler when parsing the circuit schema. It panics if + // called inside circuit.Define() + AddSecretVariable(name string) Variable +} diff --git a/frontend/cs/coeff_table.go b/frontend/cs/coeff_table.go new file mode 100644 index 0000000000..b3fba3ddee --- /dev/null +++ b/frontend/cs/coeff_table.go @@ -0,0 +1,75 @@ +package cs + +import ( + "math/big" + + "github.com/consensys/gnark/frontend/compiled" +) + +// CoeffTable helps build a constraint system but need not be serialized after compilation +type CoeffTable struct { + // Coefficients in the constraints + Coeffs []big.Int // list of unique coefficients. + CoeffsIDsLarge map[string]int // map to check existence of a coefficient (key = coeff.Bytes()) + CoeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) +} + +func NewCoeffTable() CoeffTable { + st := CoeffTable{ + Coeffs: make([]big.Int, 4), + CoeffsIDsLarge: make(map[string]int), + CoeffsIDsInt64: make(map[int64]int, 4), + } + + st.Coeffs[compiled.CoeffIdZero].SetInt64(0) + st.Coeffs[compiled.CoeffIdOne].SetInt64(1) + st.Coeffs[compiled.CoeffIdTwo].SetInt64(2) + st.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) + st.CoeffsIDsInt64[0] = compiled.CoeffIdZero + st.CoeffsIDsInt64[1] = compiled.CoeffIdOne + st.CoeffsIDsInt64[2] = compiled.CoeffIdTwo + st.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne + + return st +} + +// CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to +// the list of Coeffs and returns the corresponding entry +func (t *CoeffTable) CoeffID(v *big.Int) int { + + // if the coeff is a int64 we have a fast path. + if v.IsInt64() { + return t.coeffID64(v.Int64()) + } + + // GobEncode is 3x faster than b.Text(16). Slightly slower than Bytes, but Bytes return the same + // thing for -x and x . + bKey, _ := v.GobEncode() + key := string(bKey) + + // if the coeff is already stored, fetch its ID from the cs.CoeffsIDs map + if idx, ok := t.CoeffsIDsLarge[key]; ok { + return idx + } + + // else add it in the cs.Coeffs map and update the cs.CoeffsIDs map + var bCopy big.Int + bCopy.Set(v) + resID := len(t.Coeffs) + t.Coeffs = append(t.Coeffs, bCopy) + t.CoeffsIDsLarge[key] = resID + return resID +} + +func (t *CoeffTable) coeffID64(v int64) int { + if resID, ok := t.CoeffsIDsInt64[v]; ok { + return resID + } else { + var bCopy big.Int + bCopy.SetInt64(v) + resID := len(t.Coeffs) + t.Coeffs = append(t.Coeffs, bCopy) + t.CoeffsIDsInt64[v] = resID + return resID + } +} diff --git a/frontend/cs/cs.go b/frontend/cs/cs.go deleted file mode 100644 index 3bad64cd4d..0000000000 --- a/frontend/cs/cs.go +++ /dev/null @@ -1,135 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package cs - -import ( - "math/big" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/debug" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" -) - -// ConstraintSystem contains the parts common to plonk and Groth16 -type ConstraintSystem struct { - compiled.CS - - // input wires - Public, Secret []string - - CurveID ecc.ID - // BackendID backend.ID - - // Coefficients in the constraints - Coeffs []big.Int // list of unique coefficients. - CoeffsIDsLarge map[string]int // map to check existence of a coefficient (key = coeff.Bytes()) - CoeffsIDsInt64 map[int64]int // map to check existence of a coefficient (key = int64 value) - - // map for recording boolean constrained variables (to not constrain them twice) - MTBooleans map[int]struct{} -} - -func (cs *ConstraintSystem) Curve() ecc.ID { - return cs.CurveID -} - -func (cs *ConstraintSystem) CoeffID64(v int64) int { - if resID, ok := cs.CoeffsIDsInt64[v]; ok { - return resID - } else { - var bCopy big.Int - bCopy.SetInt64(v) - resID := len(cs.Coeffs) - cs.Coeffs = append(cs.Coeffs, bCopy) - cs.CoeffsIDsInt64[v] = resID - return resID - } -} - -// CoeffID tries to fetch the entry where b is if it exits, otherwise appends b to -// the list of Coeffs and returns the corresponding entry -func (cs *ConstraintSystem) CoeffID(b *big.Int) int { - - // if the coeff is a int64 we have a fast path. - if b.IsInt64() { - return cs.CoeffID64(b.Int64()) - } - - // GobEncode is 3x faster than b.Text(16). Slightly slower than Bytes, but Bytes return the same - // thing for -x and x . - bKey, _ := b.GobEncode() - key := string(bKey) - - // if the coeff is already stored, fetch its ID from the cs.CoeffsIDs map - if idx, ok := cs.CoeffsIDsLarge[key]; ok { - return idx - } - - // else add it in the cs.Coeffs map and update the cs.CoeffsIDs map - var bCopy big.Int - bCopy.Set(b) - resID := len(cs.Coeffs) - cs.Coeffs = append(cs.Coeffs, bCopy) - cs.CoeffsIDsLarge[key] = resID - return resID -} - -func (cs *ConstraintSystem) AddDebugInfo(errName string, i ...interface{}) int { - - var l compiled.LogEntry - - const minLogSize = 500 - var sbb strings.Builder - sbb.Grow(minLogSize) - sbb.WriteString("[") - sbb.WriteString(errName) - sbb.WriteString("] ") - - for _, _i := range i { - switch v := _i.(type) { - case compiled.Variable: - if len(v.LinExp) > 1 { - sbb.WriteString("(") - } - l.WriteVariable(v, &sbb) - if len(v.LinExp) > 1 { - sbb.WriteString(")") - } - case string: - sbb.WriteString(v) - case compiled.Term: - l.WriteTerm(v, &sbb) - default: - _v := utils.FromInterface(v) - sbb.WriteString(_v.String()) - } - } - sbb.WriteByte('\n') - debug.WriteStack(&sbb) - l.Format = sbb.String() - - cs.DebugInfo = append(cs.DebugInfo, l) - - return len(cs.DebugInfo) - 1 -} - -// bitLen returns the number of bits needed to represent a fr.Element -func (cs *ConstraintSystem) BitLen() int { - return cs.CurveID.Info().Fr.Bits -} diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/plonk/conversion.go deleted file mode 100644 index 1f2e8d564e..0000000000 --- a/frontend/cs/plonk/conversion.go +++ /dev/null @@ -1,243 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package plonk - -import ( - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/schema" - bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" - bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" - bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" - bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" - bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" - bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" - "github.com/consensys/gnark/internal/backend/compiled" -) - -func (cs *sparseR1CS) Compile() (frontend.CompiledConstraintSystem, error) { - - res := compiled.SparseR1CS{ - CS: cs.CS, - Constraints: cs.Constraints, - } - res.NbPublicVariables = len(cs.Public) - res.NbSecretVariables = len(cs.Secret) - - // Logs, DebugInfo and hints are copied, the only thing that will change - // is that ID of the wires will be offseted to take into account the final wire vector ordering - // that is: public wires | secret wires | internal wires - - // shift variable ID - // we want publicWires | privateWires | internalWires - shiftVID := func(oldID int, visibility schema.Visibility) int { - switch visibility { - case schema.Internal: - return oldID + res.NbPublicVariables + res.NbSecretVariables - case schema.Public: - return oldID - case schema.Secret: - return oldID + res.NbPublicVariables - default: - return oldID - } - } - - offsetTermID := func(t *compiled.Term) { - _, VID, visibility := t.Unpack() - t.SetWireID(shiftVID(VID, visibility)) - } - - // offset the IDs of all constraints so that the variables are - // numbered like this: [publicVariables | secretVariables | internalVariables ] - for i := 0; i < len(res.Constraints); i++ { - r1c := &res.Constraints[i] - offsetTermID(&r1c.L) - offsetTermID(&r1c.R) - offsetTermID(&r1c.O) - offsetTermID(&r1c.M[0]) - offsetTermID(&r1c.M[1]) - } - - // we need to offset the ids in Logs & DebugInfo - for i := 0; i < len(cs.Logs); i++ { - for j := 0; j < len(res.Logs[i].ToResolve); j++ { - offsetTermID(&res.Logs[i].ToResolve[j]) - } - } - for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { - offsetTermID(&res.DebugInfo[i].ToResolve[j]) - } - } - - // we need to offset the ids in the hints - shiftedMap := make(map[int]*compiled.Hint) -HINTLOOP: - for _, hint := range cs.MHints { - ws := make([]int, len(hint.Wires)) - // we set for all outputs in shiftedMap. If one shifted output - // is in shiftedMap, then all are - for i, vID := range hint.Wires { - ws[i] = shiftVID(vID, schema.Internal) - if _, ok := shiftedMap[ws[i]]; i == 0 && ok { - continue HINTLOOP - } - } - inputs := make([]interface{}, len(hint.Inputs)) - copy(inputs, hint.Inputs) - for j := 0; j < len(inputs); j++ { - switch t := inputs[j].(type) { - case compiled.Term: - offsetTermID(&t) - inputs[j] = t // TODO check if we can remove it - default: - inputs[j] = t - } - } - ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} - for _, vID := range ws { - shiftedMap[vID] = ch - } - } - res.MHints = shiftedMap - - // build levels - res.Levels = buildLevels(res) - - switch cs.CurveID { - case ecc.BLS12_377: - return bls12377r1cs.NewSparseR1CS(res, cs.Coeffs), nil - case ecc.BLS12_381: - return bls12381r1cs.NewSparseR1CS(res, cs.Coeffs), nil - case ecc.BN254: - return bn254r1cs.NewSparseR1CS(res, cs.Coeffs), nil - case ecc.BW6_761: - return bw6761r1cs.NewSparseR1CS(res, cs.Coeffs), nil - case ecc.BLS24_315: - return bls24315r1cs.NewSparseR1CS(res, cs.Coeffs), nil - case ecc.BW6_633: - return bw6633r1cs.NewSparseR1CS(res, cs.Coeffs), nil - default: - panic("unknown curveID") - } - -} - -func (cs *sparseR1CS) SetSchema(s *schema.Schema) { - cs.Schema = s -} - -func buildLevels(ccs compiled.SparseR1CS) [][]int { - - b := levelBuilder{ - mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire - nodeLevels: make([]int, len(ccs.Constraints)), // level of a node - mLevels: make(map[int]int), // level counts - ccs: ccs, - nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, - } - - // for each constraint, we're going to find its direct dependencies - // that is, wires (solved by previous constraints) on which it depends - // each of these dependencies is tagged with a level - // current constraint will be tagged with max(level) + 1 - for cID, c := range ccs.Constraints { - - b.nodeLevel = 0 - - b.processTerm(c.L, cID) - b.processTerm(c.R, cID) - b.processTerm(c.O, cID) - - b.nodeLevels[cID] = b.nodeLevel - b.mLevels[b.nodeLevel]++ - - } - - levels := make([][]int, len(b.mLevels)) - for i := 0; i < len(levels); i++ { - // allocate memory - levels[i] = make([]int, 0, b.mLevels[i]) - } - - for n, l := range b.nodeLevels { - levels[l] = append(levels[l], n) - } - - return levels -} - -type levelBuilder struct { - ccs compiled.SparseR1CS - nbInputs int - - mWireToNode map[int]int // at which node we resolved which wire - nodeLevels []int // level per node - mLevels map[int]int // number of constraint per level - - nodeLevel int // current level -} - -func (b *levelBuilder) processTerm(t compiled.Term, cID int) { - wID := t.WireID() - if wID < b.nbInputs { - // it's a input, we ignore it - return - } - - // if we know a which constraint solves this wire, then it's a dependency - n, ok := b.mWireToNode[wID] - if ok { - if n != cID { // can happen with hints... - // we add a dependency, check if we need to increment our current level - if b.nodeLevels[n] >= b.nodeLevel { - b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it - } - } - return - } - - // check if it's a hint and mark all the output wires - if h, ok := b.ccs.MHints[wID]; ok { - - for _, in := range h.Inputs { - switch t := in.(type) { - case compiled.Variable: - for _, tt := range t.LinExp { - b.processTerm(tt, cID) - } - case compiled.LinearExpression: - for _, tt := range t { - b.processTerm(tt, cID) - } - case compiled.Term: - b.processTerm(t, cID) - } - } - - for _, hwid := range h.Wires { - b.mWireToNode[hwid] = cID - } - - return - } - - // mark this wire solved by current node - b.mWireToNode[wID] = cID - -} diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/plonk/sparse_r1cs.go deleted file mode 100644 index 2e4beac548..0000000000 --- a/frontend/cs/plonk/sparse_r1cs.go +++ /dev/null @@ -1,281 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package plonk - -import ( - "errors" - "math/big" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs" - "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" -) - -func NewBuilder(curve ecc.ID) (frontend.Builder, error) { - return newSparseR1CS(curve), nil -} - -type sparseR1CS struct { - cs.ConstraintSystem - - Constraints []compiled.SparseR1C -} - -// initialCapacity has quite some impact on frontend performance, especially on large circuits size -// we may want to add build tags to tune that -func newSparseR1CS(curveID ecc.ID, initialCapacity ...int) *sparseR1CS { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } - system := sparseR1CS{ - ConstraintSystem: cs.ConstraintSystem{ - - CS: compiled.CS{ - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, - - Coeffs: make([]big.Int, 4), - CoeffsIDsLarge: make(map[string]int), - CoeffsIDsInt64: make(map[int64]int, 4), - MTBooleans: make(map[int]struct{}), - }, - Constraints: make([]compiled.SparseR1C, 0, capacity), - } - - system.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - - // system.public.variables = make([]Variable, 0) - // system.secret.variables = make([]Variable, 0) - // system.internal = make([]Variable, 0, capacity) - system.Public = make([]string, 0) - system.Secret = make([]string, 0) - - system.CurveID = curveID - - return &system -} - -// addPlonkConstraint creates a constraint of the for al+br+clr+k=0 -//func (system *SparseR1CS) addPlonkConstraint(l, r, o frontend.Variable, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { -func (system *sparseR1CS) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { - - if len(debugID) > 0 { - system.MDebug[len(system.Constraints)] = debugID[0] - } - - l.SetCoeffID(cidl) - r.SetCoeffID(cidr) - o.SetCoeffID(cido) - - u := l - v := r - u.SetCoeffID(cidm1) - v.SetCoeffID(cidm2) - - //system.Constraints = append(system.Constraints, compiled.SparseR1C{L: _l, R: _r, O: _o, M: [2]compiled.Term{u, v}, K: k}) - system.Constraints = append(system.Constraints, compiled.SparseR1C{L: l, R: r, O: o, M: [2]compiled.Term{u, v}, K: k}) -} - -// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets -// the wire's id to the number of wires, and returns it -func (system *sparseR1CS) newInternalVariable() compiled.Term { - idx := system.NbInternalVariables - system.NbInternalVariables++ - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) -} - -// NewPublicVariable creates a new Public Variable -func (system *sparseR1CS) NewPublicVariable(name string) frontend.Variable { - idx := len(system.Public) - system.Public = append(system.Public, name) - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public) -} - -// NewPublicVariable creates a new Secret Variable -func (system *sparseR1CS) NewSecretVariable(name string) frontend.Variable { - idx := len(system.Secret) - system.Secret = append(system.Secret, name) - return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret) -} - -// reduces redundancy in linear expression -// It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset -// for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExpression { - - // ensure our linear expression is sorted, by visibility and by Variable ID - sort.Sort(l) - - mod := system.CurveID.Info().Fr.Modulus() - c := new(big.Int) - for i := 1; i < len(l); i++ { - pcID, pvID, pVis := l[i-1].Unpack() - ccID, cvID, cVis := l[i].Unpack() - if pVis == cVis && pvID == cvID { - // we have redundancy - c.Add(&system.Coeffs[pcID], &system.Coeffs[ccID]) - c.Mod(c, mod) - l[i-1].SetCoeffID(system.CoeffID(c)) - l = append(l[:i], l[i+1:]...) - i-- - } - } - return l -} - -// to handle wires that don't exist (=coef 0) in a sparse constraint -func (system *sparseR1CS) zero() compiled.Term { - var a compiled.Term - return a -} - -// returns true if a variable is already boolean -func (system *sparseR1CS) isBoolean(t compiled.Term) bool { - _, ok := system.MTBooleans[int(t)] - return ok -} - -// markBoolean records t in the map to not boolean constrain it twice -func (system *sparseR1CS) markBoolean(t compiled.Term) { - system.MTBooleans[int(t)] = struct{}{} -} - -// checkVariables perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (system *sparseR1CS) CheckVariables() error { - - // TODO @gbotrel add unit test for that. - - cptSecret := len(system.Secret) - cptPublic := len(system.Public) - cptHints := len(system.MHints) - - // compared to R1CS, we may have a circuit which does not have any inputs - // (R1CS always has a constant ONE wire). Check the edge case and omit any - // processing if so. - if cptSecret+cptPublic+cptHints == 0 { - return nil - } - - secretConstrained := make([]bool, cptSecret) - publicConstrained := make([]bool, cptPublic) - - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the terms and mark our inputs / hints as constrained - processTerm := func(t compiled.Term) { - - // L and M[0] handles the same wire but with a different coeff - visibility := t.VariableVisibility() - vID := t.WireID() - if t.CoeffID() != compiled.CoeffIdZero { - switch visibility { - case schema.Public: - if !publicConstrained[vID] { - publicConstrained[vID] = true - cptPublic-- - } - case schema.Secret: - if !secretConstrained[vID] { - secretConstrained[vID] = true - cptSecret-- - } - case schema.Internal: - if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - - } - for _, c := range system.Constraints { - processTerm(c.L) - processTerm(c.R) - processTerm(c.M[0]) - processTerm(c.M[1]) - processTerm(c.O) - if cptHints|cptSecret|cptPublic == 0 { - return nil // we can stop. - } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptSecret != 0 { - sbb.WriteString(strconv.Itoa(cptSecret)) - sbb.WriteString(" unconstrained secret input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { - if !secretConstrained[i] { - sbb.WriteString(system.Secret[i]) - sbb.WriteByte('\n') - cptSecret-- - } - } - sbb.WriteByte('\n') - } - - if cptPublic != 0 { - sbb.WriteString(strconv.Itoa(cptPublic)) - sbb.WriteString(" unconstrained public input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { - if !publicConstrained[i] { - sbb.WriteString(system.Public[i]) - sbb.WriteByte('\n') - cptPublic-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/frontend/cs/r1cs/api.go b/frontend/cs/r1cs/api.go index 9fff35cd11..e437a5bce2 100644 --- a/frontend/cs/r1cs/api.go +++ b/frontend/cs/r1cs/api.go @@ -25,30 +25,27 @@ import ( "strconv" "strings" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" ) // --------------------------------------------------------------------------------------------- // Arithmetic // Add returns res = i1+i2+...in -func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - t := false - res := compiled.Variable{LinExp: make([]compiled.Term, 0, s), IsBoolean: &t} + res := make(compiled.LinearExpression, 0, s) for _, v := range vars { l := v.Clone() - res.LinExp = append(res.LinExp, l.LinExp...) + res = append(res, l...) } res = system.reduce(res) @@ -57,39 +54,31 @@ func (system *r1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front } // Neg returns -i -func (system *r1CS) Neg(i frontend.Variable) frontend.Variable { +func (system *r1cs) Neg(i frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i) - if vars[0].IsConstant() { - n := system.constantValue(vars[0]) + if n, ok := system.ConstantValue(vars[0]); ok { n.Neg(n) - return system.constant(n) + return system.toVariable(n) } - // ok to pass pointer since if i is boolean constrained later, so must be res - res := compiled.Variable{LinExp: system.negateLinExp(vars[0].LinExp), IsBoolean: vars[0].IsBoolean} - - return res + return system.negateLinExp(vars[0]) } // Sub returns res = i1 - i2 -func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input vars, s := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) // allocate resulting frontend.Variable - t := false - res := compiled.Variable{ - LinExp: make([]compiled.Term, 0, s), - IsBoolean: &t, - } + res := make(compiled.LinearExpression, 0, s) c := vars[0].Clone() - res.LinExp = append(res.LinExp, c.LinExp...) + res = append(res, c...) for i := 1; i < len(vars); i++ { - negLinExp := system.negateLinExp(vars[i].LinExp) - res.LinExp = append(res.LinExp, negLinExp...) + negLinExp := system.negateLinExp(vars[i]) + res = append(res, negLinExp...) } // reduce linear expression @@ -99,29 +88,29 @@ func (system *r1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) front } // Mul returns res = i1 * i2 * ... in -func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *r1cs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, _ := system.toVariables(append([]frontend.Variable{i1, i2}, in...)...) - mul := func(v1, v2 compiled.Variable) compiled.Variable { + mul := func(v1, v2 compiled.LinearExpression) compiled.LinearExpression { + + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) // v1 and v2 are both unknown, this is the only case we add a constraint - if !v1.IsConstant() && !v2.IsConstant() { + if !v1Constant && !v2Constant { res := system.newInternalVariable() system.Constraints = append(system.Constraints, newR1C(v1, v2, res)) return res } // v1 and v2 are constants, we multiply big.Int values and return resulting constant - if v1.IsConstant() && v2.IsConstant() { - b1 := system.constantValue(v1) - b2 := system.constantValue(v2) - - b1.Mul(b1, b2).Mod(b1, system.CurveID.Info().Fr.Modulus()) - return system.constant(b1).(compiled.Variable) + if v1Constant && v2Constant { + n1.Mul(n1, n2).Mod(n1, system.CurveID.Info().Fr.Modulus()) + return system.toVariable(n1).(compiled.LinearExpression) } // ensure v2 is the constant - if v1.IsConstant() { + if v1Constant { v1, v2 = v2, v1 } @@ -137,13 +126,13 @@ func (system *r1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) front return res } -func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variable { +func (system *r1cs) mulConstant(v1, constant compiled.LinearExpression) compiled.LinearExpression { // multiplying a frontend.Variable by a constant -> we updated the coefficients in the linear expression // leading to that frontend.Variable res := v1.Clone() - lambda := system.constantValue(constant) + lambda, _ := system.ConstantValue(constant) - for i, t := range v1.LinExp { + for i, t := range v1 { cID, vID, visibility := t.Unpack() var newCoeff big.Int switch cID { @@ -156,23 +145,24 @@ func (system *r1CS) mulConstant(v1, constant compiled.Variable) compiled.Variabl case compiled.CoeffIdTwo: newCoeff.Add(lambda, lambda) default: - coeff := system.Coeffs[cID] + coeff := system.st.Coeffs[cID] newCoeff.Mul(&coeff, lambda) } - res.LinExp[i] = compiled.Pack(vID, system.CoeffID(&newCoeff), visibility) + res[i] = compiled.Pack(vID, system.st.CoeffID(&newCoeff), visibility) } - t := false - res.IsBoolean = &t return res } -func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1, i2) v1 := vars[0] v2 := vars[1] - if !v2.IsConstant() { + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) + + if !v2Constant { res := system.newInternalVariable() debug := system.AddDebugInfo("div", v1, "/", v2, " == ", res) // note that here we don't ensure that divisor is != 0 @@ -181,30 +171,32 @@ func (system *r1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { } // v2 is constant - b2 := system.constantValue(v2) - if b2.IsUint64() && b2.Uint64() == 0 { + if n2.IsUint64() && n2.Uint64() == 0 { panic("div by constant(0)") } q := system.CurveID.Info().Fr.Modulus() - b2.ModInverse(b2, q) + n2.ModInverse(n2, q) - if v1.IsConstant() { - b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.constant(b2) + if v1Constant { + n2.Mul(n2, n1).Mod(n2, q) + return system.toVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.constant(b2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.LinearExpression)) } // Div returns res = i1 / i2 -func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) Div(i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1, i2) v1 := vars[0] v2 := vars[1] - if !v2.IsConstant() { + n1, v1Constant := system.ConstantValue(v1) + n2, v2Constant := system.ConstantValue(v2) + + if !v2Constant { res := system.newInternalVariable() debug := system.AddDebugInfo("div", v1, "/", v2, " == ", res) v2Inv := system.newInternalVariable() @@ -215,35 +207,32 @@ func (system *r1CS) Div(i1, i2 frontend.Variable) frontend.Variable { } // v2 is constant - b2 := system.constantValue(v2) - if b2.IsUint64() && b2.Uint64() == 0 { + if n2.IsUint64() && n2.Uint64() == 0 { panic("div by constant(0)") } q := system.CurveID.Info().Fr.Modulus() - b2.ModInverse(b2, q) + n2.ModInverse(n2, q) - if v1.IsConstant() { - b2.Mul(b2, system.constantValue(v1)).Mod(b2, q) - return system.constant(b2) + if v1Constant { + n2.Mul(n2, n1).Mod(n2, q) + return system.toVariable(n2) } // v1 is not constant - return system.mulConstant(v1, system.constant(b2).(compiled.Variable)) + return system.mulConstant(v1, system.toVariable(n2).(compiled.LinearExpression)) } // Inverse returns res = inverse(v) -func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { +func (system *r1cs) Inverse(i1 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1) - if vars[0].IsConstant() { - // c := vars[0].constantValue(cs) - c := system.constantValue(vars[0]) + if c, ok := system.ConstantValue(vars[0]); ok { if c.IsUint64() && c.Uint64() == 0 { panic("inverse by constant(0)") } c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) - return system.constant(c) + return system.toVariable(c) } // allocate resulting frontend.Variable @@ -263,7 +252,7 @@ func (system *r1CS) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *r1cs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -278,50 +267,46 @@ func (system *r1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable a := vars[0] // if a is a constant, work with the big int value. - if a.IsConstant() { - c := system.constantValue(a) - b := make([]compiled.Variable, nbBits) + if c, ok := system.ConstantValue(a); ok { + b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { - b[i] = system.constant(c.Bit(i)).(compiled.Variable) + b[i] = system.toVariable(c.Bit(i)) } - return toSliceOfVariables(b) + return b } return system.toBinary(a, nbBits, false) } // toBinary is equivalent to ToBinary, exept the returned bits are NOT boolean constrained. -func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []frontend.Variable { +func (system *r1cs) toBinary(a compiled.LinearExpression, nbBits int, unsafe bool) []frontend.Variable { - if a.IsConstant() { + if _, ok := system.ConstantValue(a); ok { return system.ToBinary(a, nbBits) } - // ensure a is set - a.AssertIsSet() - // allocate the resulting frontend.Variables and bit-constraint them - b := make([]frontend.Variable, nbBits) sb := make([]frontend.Variable, nbBits) var c big.Int c.SetUint64(1) + + bits, err := system.NewHint(hint.NBits, nbBits, a) + if err != nil { + panic(err) + } + for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, a, i) - if err != nil { - panic(err) - } - b[i] = res[0] - sb[i] = system.Mul(b[i], c) + sb[i] = system.Mul(bits[i], c) c.Lsh(&c, 1) if !unsafe { - system.AssertIsBoolean(b[i]) + system.AssertIsBoolean(bits[i]) } } - //var Σbi compiled.Variable + //var Σbi compiled.LinearExpression var Σbi frontend.Variable if nbBits == 1 { - system.AssertIsEqual(sb[0], a) + Σbi = sb[0] } else if nbBits == 2 { Σbi = system.Add(sb[0], sb[1]) } else { @@ -330,37 +315,23 @@ func (system *r1CS) toBinary(a compiled.Variable, nbBits int, unsafe bool) []fro system.AssertIsEqual(Σbi, a) // record the constraint Σ (2**i * b[i]) == a - return b + return bits } -func toSliceOfVariables(v []compiled.Variable) []frontend.Variable { - // TODO this is ugly. - r := make([]frontend.Variable, len(v)) - for i := 0; i < len(v); i++ { - r[i] = v[i] - } - return r -} - // FromBinary packs b, seen as a fr.Element in little endian -func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { +func (system *r1cs) FromBinary(_b ...frontend.Variable) frontend.Variable { b, _ := system.toVariables(_b...) - // ensure inputs are set - for i := 0; i < len(b); i++ { - b[i].AssertIsSet() - } - // res = Σ (2**i * b[i]) var res, v frontend.Variable - res = system.constant(0) // no constraint is recorded + res = system.toVariable(0) // no constraint is recorded var c big.Int c.SetUint64(1) - L := make([]compiled.Term, len(b)) + L := make(compiled.LinearExpression, len(b)) for i := 0; i < len(L); i++ { v = system.Mul(c, b[i]) // no constraint is recorded res = system.Add(v, res) // no constraint is recorded @@ -372,7 +343,7 @@ func (system *r1CS) FromBinary(_b ...frontend.Variable) frontend.Variable { } // Xor compute the XOR between two frontend.Variables -func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) Xor(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) @@ -384,12 +355,9 @@ func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() - res.IsBoolean = new(bool) - *res.IsBoolean = true - c := system.Neg(res).(compiled.Variable) - c.IsBoolean = new(bool) - *c.IsBoolean = false - c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) + system.MarkBoolean(res) + c := system.Neg(res).(compiled.LinearExpression) + c = append(c, a[0], b[0]) aa := system.Mul(a, 2) system.Constraints = append(system.Constraints, newR1C(aa, b, c)) @@ -397,7 +365,7 @@ func (system *r1CS) Xor(_a, _b frontend.Variable) frontend.Variable { } // Or compute the OR between two frontend.Variables -func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) Or(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) a := vars[0] @@ -408,19 +376,16 @@ func (system *r1CS) Or(_a, _b frontend.Variable) frontend.Variable { // the formulation used is for easing up the conversion to sparse r1cs res := system.newInternalVariable() - res.IsBoolean = new(bool) - *res.IsBoolean = true - c := system.Neg(res).(compiled.Variable) - c.IsBoolean = new(bool) - *c.IsBoolean = false - c.LinExp = append(c.LinExp, a.LinExp[0], b.LinExp[0]) + system.MarkBoolean(res) + c := system.Neg(res).(compiled.LinearExpression) + c = append(c, a[0], b[0]) system.Constraints = append(system.Constraints, newR1C(a, b, c)) return res } // And compute the AND between two frontend.Variables -func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { +func (system *r1cs) And(_a, _b frontend.Variable) frontend.Variable { vars, _ := system.toVariables(_a, _b) a := vars[0] @@ -438,17 +403,17 @@ func (system *r1CS) And(_a, _b frontend.Variable) frontend.Variable { // Conditionals // Select if i0 is true, yields i1 else yields i2 -func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { +func (system *r1cs) Select(i0, i1, i2 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i0, i1, i2) b := vars[0] // ensures that b is boolean system.AssertIsBoolean(b) + n1, ok1 := system.ConstantValue(vars[1]) + n2, ok2 := system.ConstantValue(vars[2]) - if vars[1].IsConstant() && vars[2].IsConstant() { - n1 := system.constantValue(vars[1]) - n2 := system.constantValue(vars[2]) + if ok1 && ok2 { diff := n1.Sub(n1, n2) res := system.Mul(b, diff) // no constraint is recorded res = system.Add(res, vars[2]) // no constraint is recorded @@ -456,8 +421,7 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { } // special case appearing in AssertIsLessOrEq - if vars[1].IsConstant() { - n1 := system.constantValue(vars[1]) + if ok1 { if n1.IsUint64() && n1.Uint64() == 0 { v := system.Sub(1, vars[0]) return system.Mul(v, vars[2]) @@ -473,7 +437,7 @@ func (system *r1CS) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { +func (system *r1cs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -504,16 +468,14 @@ func (system *r1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Va } // IsZero returns 1 if i1 is zero, 0 otherwise -func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { +func (system *r1cs) IsZero(i1 frontend.Variable) frontend.Variable { vars, _ := system.toVariables(i1) a := vars[0] - if a.IsConstant() { - // c := a.constantValue(cs) - c := system.constantValue(a) + if c, ok := system.ConstantValue(a); ok { if c.IsUint64() && c.Uint64() == 0 { - return system.constant(1) + return system.toVariable(1) } - return system.constant(0) + return system.toVariable(0) } debug := system.AddDebugInfo("isZero", a) @@ -523,13 +485,13 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { // _ = inverse(m + a) // constrain m to be 1 if a == 0 // m is computed by the solver such that m = 1 - a^(modulus - 1) - res, err := system.NewHint(hint.IsZero, a) + res, err := system.NewHint(hint.IsZero, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) } m := res[0] - system.addConstraint(newR1C(a, m, system.constant(0)), debug) + system.addConstraint(newR1C(a, m, system.toVariable(0)), debug) system.AssertIsBoolean(m) ma := system.Add(m, a) @@ -538,13 +500,13 @@ func (system *r1CS) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1= 0; i-- { @@ -563,40 +525,12 @@ func (system *r1CS) Cmp(i1, i2 frontend.Variable) frontend.Variable { return res } -// --------------------------------------------------------------------------------------------- -// Assertions - -// IsConstant returns true if v is a constant known at compile time -func (system *r1CS) IsConstant(v frontend.Variable) bool { - if _v, ok := v.(compiled.Variable); ok { - return _v.IsConstant() - } - // it's not a wire, it's another golang type, we consider it constant. - // TODO we may want to use the struct parser to ensure this frontend.Variable interface doesn't contain fields which are - // frontend.Variable - return true -} - -// ConstantValue returns the big.Int value of v. -// Will panic if v.IsConstant() == false -func (system *r1CS) ConstantValue(v frontend.Variable) *big.Int { - if _v, ok := v.(compiled.Variable); ok { - return system.constantValue(_v) - } - r := utils.FromInterface(v) - return &r -} - -func (system *r1CS) Backend() backend.ID { - return backend.GROTH16 -} - // Println enables circuit debugging and behaves almost like fmt.Println() // // the print will be done once the R1CS.Solve() method is executed // // if one of the input is a variable, its value will be resolved avec R1CS.Solve() method is called -func (system *r1CS) Println(a ...frontend.Variable) { +func (system *r1cs) Println(a ...frontend.Variable) { var sbb strings.Builder // prefix log line with file.go:line @@ -613,14 +547,14 @@ func (system *r1CS) Println(a ...frontend.Variable) { if i > 0 { sbb.WriteByte(' ') } - if v, ok := arg.(compiled.Variable); ok { - v.AssertIsSet() + if v, ok := arg.(compiled.LinearExpression); ok { + assertIsSet(v) sbb.WriteString("%s") // we set limits to the linear expression, so that the log printer // can evaluate it before printing it log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - log.ToResolve = append(log.ToResolve, v.LinExp...) + log.ToResolve = append(log.ToResolve, v...) log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) } else { printArg(&log, &sbb, arg) @@ -660,11 +594,11 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) sbb.WriteString(", ") } - v := tValue.Interface().(compiled.Variable) + v := tValue.Interface().(compiled.LinearExpression) // we set limits to the linear expression, so that the log printer // can evaluate it before printing it log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) - log.ToResolve = append(log.ToResolve, v.LinExp...) + log.ToResolve = append(log.ToResolve, v...) log.ToResolve = append(log.ToResolve, compiled.TermDelimitor) return nil } @@ -673,132 +607,19 @@ func printArg(log *compiled.LogEntry, sbb *strings.Builder, a frontend.Variable) sbb.WriteByte('}') } -// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to -// measure constraints, variables and coefficients creations through AddCounter -func (system *r1CS) Tag(name string) frontend.Tag { - _, file, line, _ := runtime.Caller(1) - - return frontend.Tag{ - Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), - VID: system.NbInternalVariables, - CID: len(system.Constraints), - } -} - -// AddCounter measures the number of constraints, variables and coefficients created between two tags -func (system *r1CS) AddCounter(from, to frontend.Tag) { - system.Counters = append(system.Counters, compiled.Counter{ - From: from.Name, - To: to.Name, - NbVariables: to.VID - from.VID, - NbConstraints: to.CID - from.CID, - CurveID: system.CurveID, - BackendID: backend.GROTH16, - }) -} - -// NewHint initializes internal variables whose value will be evaluated using -// the provided hint function at run time from the inputs. Inputs must be either -// variables or convertible to *big.Int. The function returns an error if the -// number of inputs is not compatible with f. -// -// The hint function is provided at the proof creation time and is not embedded -// into the circuit. From the backend point of view, the variable returned by -// the hint function is equivalent to the user-supplied witness, but its actual -// value is assigned by the solver, not the caller. -// -// No new constraints are added to the newly created wire and must be added -// manually in the circuit. Failing to do so leads to solver failure. -func (system *r1CS) NewHint(f hint.Function, inputs ...frontend.Variable) ([]frontend.Variable, error) { - - if f.NbOutputs(system.Curve(), len(inputs)) <= 0 { - return nil, fmt.Errorf("hint function must return at least one output") - } - hintInputs := make([]interface{}, len(inputs)) - - // ensure inputs are set and pack them in a []uint64 - for i, in := range inputs { - switch t := in.(type) { - case compiled.Variable: - tmp := t.Clone() - hintInputs[i] = tmp.LinExp - case compiled.LinearExpression: - tmp := make(compiled.LinearExpression, len(t)) - copy(tmp, t) - hintInputs[i] = tmp - default: - hintInputs[i] = utils.FromInterface(t) - } - } - - // prepare wires - varIDs := make([]int, f.NbOutputs(system.Curve(), len(inputs))) - res := make([]frontend.Variable, len(varIDs)) - for i := range varIDs { - r := system.newInternalVariable() - _, vID, _ := r.LinExp[0].Unpack() - varIDs[i] = vID - res[i] = r - } - - ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} - for _, vID := range varIDs { - system.MHints[vID] = ch - } - - return res, nil -} - -// constant will return (and allocate if neccesary) a frontend.Variable from given value -// -// if input is already a frontend.Variable, does nothing -// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a constant frontend.Variable -// -// a constant frontend.Variable does NOT necessary allocate a frontend.Variable in the ConstraintSystem -// it is in the form ONE_WIRE * coeff -func (system *r1CS) constant(input frontend.Variable) frontend.Variable { - - switch t := input.(type) { - case compiled.Variable: - t.AssertIsSet() - return t - default: - n := utils.FromInterface(t) - if n.IsUint64() && n.Uint64() == 1 { - return system.one() - } - r := system.one() - r.LinExp[0] = system.setCoeff(r.LinExp[0], &n) - return r - } -} - -// toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (system *r1CS) toVariables(in ...frontend.Variable) ([]compiled.Variable, int) { - r := make([]compiled.Variable, 0, len(in)) - s := 0 - e := func(i frontend.Variable) { - v := system.constant(i).(compiled.Variable) - r = append(r, v) - s += len(v.LinExp) - } - // e(i1) - // e(i2) - for i := 0; i < len(in); i++ { - e(in[i]) - } - return r, s -} - // returns -le, the result is a copy -func (system *r1CS) negateLinExp(l []compiled.Term) []compiled.Term { - res := make([]compiled.Term, len(l)) +func (system *r1cs) negateLinExp(l compiled.LinearExpression) compiled.LinearExpression { + res := make(compiled.LinearExpression, len(l)) var lambda big.Int for i, t := range l { cID, vID, visibility := t.Unpack() - lambda.Neg(&system.Coeffs[cID]) - cID = system.CoeffID(&lambda) + lambda.Neg(&system.st.Coeffs[cID]) + cID = system.st.CoeffID(&lambda) res[i] = compiled.Pack(vID, cID, visibility) } return res } + +func (system *r1cs) Compiler() frontend.Compiler { + return system +} diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/api_assertions.go similarity index 76% rename from frontend/cs/r1cs/assertions.go rename to frontend/cs/r1cs/api_assertions.go index aba87440c0..4e0137057b 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -21,15 +21,15 @@ import ( "math/big" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/internal/utils" ) // AssertIsEqual adds an assertion in the constraint system (i1 == i2) -func (system *r1CS) AssertIsEqual(i1, i2 frontend.Variable) { +func (system *r1cs) AssertIsEqual(i1, i2 frontend.Variable) { // encoded 1 * i1 == i2 - r := system.constant(i1).(compiled.Variable) - o := system.constant(i2).(compiled.Variable) + r := system.toVariable(i1).(compiled.LinearExpression) + o := system.toVariable(i2).(compiled.LinearExpression) debug := system.AddDebugInfo("assertIsEqual", r, " == ", o) @@ -37,32 +37,31 @@ func (system *r1CS) AssertIsEqual(i1, i2 frontend.Variable) { } // AssertIsDifferent constrain i1 and i2 to be different -func (system *r1CS) AssertIsDifferent(i1, i2 frontend.Variable) { +func (system *r1cs) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } // AssertIsBoolean adds an assertion in the constraint system (v == 0 ∥ v == 1) -func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { +func (system *r1cs) AssertIsBoolean(i1 frontend.Variable) { vars, _ := system.toVariables(i1) v := vars[0] - if *v.IsBoolean { - return // compiled.Variable is already constrained - } - *v.IsBoolean = true - - if v.IsConstant() { - c := system.constantValue(v) + if c, ok := system.ConstantValue(v); ok { if !(c.IsUint64() && (c.Uint64() == 0 || c.Uint64() == 1)) { panic(fmt.Sprintf("assertIsBoolean failed: constant(%s)", c.String())) } return } + if system.IsBoolean(v) { + return // compiled.LinearExpression is already constrained + } + system.MarkBoolean(v) + debug := system.AddDebugInfo("assertIsBoolean", v, " == (0|1)") - o := system.constant(0) + o := system.toVariable(0) // ensure v * (1 - v) == 0 _v := system.Sub(1, v) @@ -75,12 +74,12 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { // // derived from: // https://github.com/zcash/zips/blob/main/protocol/protocol.pdf -func (system *r1CS) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { +func (system *r1cs) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { v, _ := system.toVariables(_v) switch b := bound.(type) { - case compiled.Variable: - b.AssertIsSet() + case compiled.LinearExpression: + assertIsSet(b) system.mustBeLessOrEqVar(v[0], b) default: system.mustBeLessOrEqCst(v[0], utils.FromInterface(b)) @@ -88,7 +87,7 @@ func (system *r1CS) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Var } -func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { +func (system *r1cs) mustBeLessOrEqVar(a, bound compiled.LinearExpression) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := system.BitLen() @@ -97,9 +96,9 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { boundBits := system.ToBinary(bound, nbBits) p := make([]frontend.Variable, nbBits+1) - p[nbBits] = system.constant(1) + p[nbBits] = system.toVariable(1) - zero := system.constant(0) + zero := system.toVariable(0) for i := nbBits - 1; i >= 0; i-- { @@ -122,14 +121,14 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - system.markBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint + system.MarkBoolean(aBits[i].(compiled.LinearExpression)) // this does not create a constraint system.addConstraint(newR1C(l, aBits[i], zero), debug) } } -func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { +func (system *r1cs) mustBeLessOrEqCst(a compiled.LinearExpression, bound big.Int) { nbBits := system.BitLen() @@ -142,7 +141,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { } // debug info - debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.constant(bound)) + debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", system.toVariable(bound)) // note that at this stage, we didn't boolean-constraint these new variables yet // (as opposed to ToBinary) @@ -159,7 +158,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { p := make([]frontend.Variable, nbBits+1) // p[i] == 1 → a[j] == c[j] for all j ⩾ i - p[nbBits] = system.constant(1) + p[nbBits] = system.toVariable(1) for i := nbBits - 1; i >= t; i-- { if bound.Bit(i) == 0 { @@ -175,8 +174,8 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { l := system.Sub(1, p[i+1]) l = system.Sub(l, aBits[i]) - system.addConstraint(newR1C(l, aBits[i], system.constant(0)), debug) - system.markBoolean(aBits[i].(compiled.Variable)) + system.addConstraint(newR1C(l, aBits[i], system.toVariable(0)), debug) + system.MarkBoolean(aBits[i].(compiled.LinearExpression)) } else { system.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/r1cs/compiler.go b/frontend/cs/r1cs/compiler.go new file mode 100644 index 0000000000..47723617a6 --- /dev/null +++ b/frontend/cs/r1cs/compiler.go @@ -0,0 +1,716 @@ +/* +Copyright © 2020 ConsenSys + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package r1cs + +import ( + "errors" + "fmt" + "math/big" + "path/filepath" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/frontend/schema" + bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" + bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" + bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" + bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" + bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" + bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" + "github.com/consensys/gnark/internal/utils" +) + +// NewCompiler returns a new R1CS compiler +func NewCompiler(curve ecc.ID, config frontend.CompileConfig) (frontend.Builder, error) { + return newCompiler(curve, config), nil +} + +type r1cs struct { + compiled.ConstraintSystem + Constraints []compiled.R1C + + st cs.CoeffTable + config frontend.CompileConfig + + // map for recording boolean constrained variables (to not constrain them twice) + mtBooleans map[uint64][]compiled.LinearExpression +} + +// initialCapacity has quite some impact on frontend performance, especially on large circuits size +// we may want to add build tags to tune that +func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *r1cs { + system := r1cs{ + ConstraintSystem: compiled.ConstraintSystem{ + + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), + }, + Constraints: make([]compiled.R1C, 0, config.Capacity), + st: cs.NewCoeffTable(), + mtBooleans: make(map[uint64][]compiled.LinearExpression), + config: config, + } + + system.Public = make([]string, 1) + system.Secret = make([]string, 0) + + // by default the circuit is given a public wire equal to 1 + system.Public[0] = "one" + + system.CurveID = curveID + + return &system +} + +// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets +// the wire's id to the number of wires, and returns it +func (system *r1cs) newInternalVariable() compiled.LinearExpression { + idx := system.NbInternalVariables + system.NbInternalVariables++ + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal), + } +} + +// AddPublicVariable creates a new public Variable +func (system *r1cs) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } + idx := len(system.Public) + system.Public = append(system.Public, name) + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Public), + } +} + +// AddSecretVariable creates a new secret Variable +func (system *r1cs) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } + idx := len(system.Secret) + system.Secret = append(system.Secret, name) + return compiled.LinearExpression{ + compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret), + } +} + +func (system *r1cs) one() compiled.LinearExpression { + return compiled.LinearExpression{ + compiled.Pack(0, compiled.CoeffIdOne, schema.Public), + } +} + +// reduces redundancy in linear expression +// It factorizes Variable that appears multiple times with != coeff Ids +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset +// for each visibility, the Variables are sorted from lowest ID to highest ID +func (system *r1cs) reduce(l compiled.LinearExpression) compiled.LinearExpression { + // ensure our linear expression is sorted, by visibility and by Variable ID + if !sort.IsSorted(l) { // may not help + sort.Sort(l) + } + + mod := system.CurveID.Info().Fr.Modulus() + c := new(big.Int) + for i := 1; i < len(l); i++ { + pcID, pvID, pVis := l[i-1].Unpack() + ccID, cvID, cVis := l[i].Unpack() + if pVis == cVis && pvID == cvID { + // we have redundancy + c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) + c.Mod(c, mod) + l[i-1].SetCoeffID(system.st.CoeffID(c)) + l = append(l[:i], l[i+1:]...) + i-- + } + } + return l +} + +// newR1C clones the linear expression associated with the Variables (to avoid offseting the ID multiple time) +// and return a R1C +func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { + l := _l.(compiled.LinearExpression) + r := _r.(compiled.LinearExpression) + o := _o.(compiled.LinearExpression) + + // interestingly, this is key to groth16 performance. + // l * r == r * l == o + // but the "l" linear expression is going to end up in the A matrix + // the "r" linear expression is going to end up in the B matrix + // the less Variable we have appearing in the B matrix, the more likely groth16.Setup + // is going to produce infinity points in pk.G1.B and pk.G2.B, which will speed up proving time + if len(l) > len(r) { + l, r = r, l + } + + return compiled.R1C{L: l.Clone(), R: r.Clone(), O: o.Clone()} +} + +func (system *r1cs) addConstraint(r1c compiled.R1C, debugID ...int) { + system.Constraints = append(system.Constraints, r1c) + if len(debugID) > 0 { + system.MDebug[len(system.Constraints)-1] = debugID[0] + } +} + +// Term packs a Variable and a coeff in a Term and returns it. +// func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { +func (system *r1cs) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { + _, vID, vVis := v.Unpack() + return compiled.Pack(vID, system.st.CoeffID(coeff), vVis) +} + +// MarkBoolean sets (but do not **constraint**!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *r1cs) MarkBoolean(v frontend.Variable) { + if b, ok := system.ConstantValue(v); ok { + if !(b.IsUint64() && b.Uint64() <= 1) { + panic("MarkBoolean called a non-boolean constant") + } + return + } + // v is a linear expression + l := v.(compiled.LinearExpression) + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list := system.mtBooleans[key] + list = append(list, l) + system.mtBooleans[key] = list +} + +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *r1cs) IsBoolean(v frontend.Variable) bool { + if b, ok := system.ConstantValue(v); ok { + return b.IsUint64() && b.Uint64() <= 1 + } + // v is a linear expression + l := v.(compiled.LinearExpression) + if !sort.IsSorted(l) { + sort.Sort(l) + } + + key := l.HashCode() + list, ok := system.mtBooleans[key] + if !ok { + return false + } + + for _, v := range list { + if v.Equal(l) { + return true + } + } + return false +} + +// checkVariables perform post compilation checks on the Variables +// +// 1. checks that all user inputs are referenced in at least one constraint +// 2. checks that all hints are constrained +func (system *r1cs) checkVariables() error { + + // TODO @gbotrel add unit test for that. + + cptSecret := len(system.Secret) + cptPublic := len(system.Public) + cptHints := len(system.MHints) + + secretConstrained := make([]bool, cptSecret) + publicConstrained := make([]bool, cptPublic) + // one wire does not need to be constrained + publicConstrained[0] = true + cptPublic-- + + mHintsConstrained := make(map[int]bool) + + // for each constraint, we check the linear expressions and mark our inputs / hints as constrained + processLinearExpression := func(l compiled.LinearExpression) { + for _, t := range l { + if t.CoeffID() == compiled.CoeffIdZero { + // ignore zero coefficient, as it does not constraint the Variable + // though, we may want to flag that IF the Variable doesn't appear else where + continue + } + visibility := t.VariableVisibility() + vID := t.WireID() + + switch visibility { + case schema.Public: + if vID != 0 && !publicConstrained[vID] { + publicConstrained[vID] = true + cptPublic-- + } + case schema.Secret: + if !secretConstrained[vID] { + secretConstrained[vID] = true + cptSecret-- + } + case schema.Internal: + if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { + mHintsConstrained[vID] = true + cptHints-- + } + } + } + } + for _, r1c := range system.Constraints { + processLinearExpression(r1c.L) + processLinearExpression(r1c.R) + processLinearExpression(r1c.O) + + if cptHints|cptSecret|cptPublic == 0 { + return nil // we can stop. + } + + } + + // something is a miss, we build the error string + var sbb strings.Builder + if cptSecret != 0 { + sbb.WriteString(strconv.Itoa(cptSecret)) + sbb.WriteString(" unconstrained secret input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { + if !secretConstrained[i] { + sbb.WriteString(system.Secret[i]) + sbb.WriteByte('\n') + cptSecret-- + } + } + sbb.WriteByte('\n') + } + + if cptPublic != 0 { + sbb.WriteString(strconv.Itoa(cptPublic)) + sbb.WriteString(" unconstrained public input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { + if !publicConstrained[i] { + sbb.WriteString(system.Public[i]) + sbb.WriteByte('\n') + cptPublic-- + } + } + sbb.WriteByte('\n') + } + + if cptHints != 0 { + sbb.WriteString(strconv.Itoa(cptHints)) + sbb.WriteString(" unconstrained hints") + sbb.WriteByte('\n') + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some + // debugInfo to find where a hint was declared (and not constrained) + } + return errors.New(sbb.String()) +} + +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() +} + +// Compile constructs a rank-1 constraint sytem +func (cs *r1cs) Compile() (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !cs.config.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } + + // wires = public wires | secret wires | internal wires + + // setting up the result + res := compiled.R1CS{ + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, + } + res.NbPublicVariables = len(cs.Public) + res.NbSecretVariables = len(cs.Secret) + + // for Logs, DebugInfo and hints the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires + + // offset variable ID depeneding on visibility + shiftVID := func(oldID int, visibility schema.Visibility) int { + switch visibility { + case schema.Internal: + return oldID + res.NbPublicVariables + res.NbSecretVariables + case schema.Public: + return oldID + case schema.Secret: + return oldID + res.NbPublicVariables + } + return oldID + } + + // we just need to offset our ids, such that wires = [ public wires | secret wires | internal wires ] + offsetIDs := func(l compiled.LinearExpression) { + for j := 0; j < len(l); j++ { + _, vID, visibility := l[j].Unpack() + l[j].SetWireID(shiftVID(vID, visibility)) + } + } + + for i := 0; i < len(res.Constraints); i++ { + offsetIDs(res.Constraints[i].L) + offsetIDs(res.Constraints[i].R) + offsetIDs(res.Constraints[i].O) + } + + // we need to offset the ids in the hints + shiftedMap := make(map[int]*compiled.Hint) + + // we need to offset the ids in the hints +HINTLOOP: + for _, hint := range cs.MHints { + ws := make([]int, len(hint.Wires)) + // we set for all outputs in shiftedMap. If one shifted output + // is in shiftedMap, then all are + for i, vID := range hint.Wires { + ws[i] = shiftVID(vID, schema.Internal) + if _, ok := shiftedMap[ws[i]]; i == 0 && ok { + continue HINTLOOP + } + } + inputs := make([]interface{}, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + switch t := inputs[j].(type) { + case compiled.LinearExpression: + tmp := make(compiled.LinearExpression, len(t)) + copy(tmp, t) + offsetIDs(tmp) + inputs[j] = tmp + default: + inputs[j] = t + } + } + ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} + for _, vID := range ws { + shiftedMap[vID] = ch + } + } + res.MHints = shiftedMap + + // we need to offset the ids in Logs & DebugInfo + for i := 0; i < len(cs.Logs); i++ { + + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() + res.Logs[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) + } + } + for i := 0; i < len(cs.DebugInfo); i++ { + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() + res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) + } + } + + // build levels + res.Levels = buildLevels(res) + + switch cs.CurveID { + case ecc.BLS12_377: + return bls12377r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BLS12_381: + return bls12381r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BN254: + return bn254r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_761: + return bw6761r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_633: + return bw6633r1cs.NewR1CS(res, cs.st.Coeffs), nil + case ecc.BLS24_315: + return bls24315r1cs.NewR1CS(res, cs.st.Coeffs), nil + default: + panic("not implemtented") + } +} + +func (cs *r1cs) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } + cs.Schema = s +} + +func buildLevels(ccs compiled.R1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processLE(c.L, cID) + b.processLE(c.R, cID) + b.processLE(c.O, cID) + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.R1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { + + for _, t := range l { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + continue + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + continue + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.LinearExpression: + b.processLE(t, cID) + case compiled.Term: + b.processLE(compiled.LinearExpression{t}, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + continue + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + } +} + +// ConstantValue returns the big.Int value of v. +// Will panic if v.IsConstant() == false +func (system *r1cs) ConstantValue(v frontend.Variable) (*big.Int, bool) { + if _v, ok := v.(compiled.LinearExpression); ok { + assertIsSet(_v) + + if len(_v) != 1 { + return nil, false + } + cID, vID, visibility := _v[0].Unpack() + if !(vID == 0 && visibility == schema.Public) { + return nil, false + } + return new(big.Int).Set(&system.st.Coeffs[cID]), true + } + r := utils.FromInterface(v) + return &r, true +} + +func (system *r1cs) Backend() backend.ID { + return backend.GROTH16 +} + +// toVariable will return (and allocate if neccesary) a compiled.LinearExpression from given value +// +// if input is already a compiled.LinearExpression, does nothing +// else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable compiled.LinearExpression +func (system *r1cs) toVariable(input interface{}) frontend.Variable { + + switch t := input.(type) { + case compiled.LinearExpression: + assertIsSet(t) + return t + default: + n := utils.FromInterface(t) + if n.IsUint64() && n.Uint64() == 1 { + return system.one() + } + r := system.one() + r[0] = system.setCoeff(r[0], &n) + return r + } +} + +// toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions +func (system *r1cs) toVariables(in ...frontend.Variable) ([]compiled.LinearExpression, int) { + r := make([]compiled.LinearExpression, 0, len(in)) + s := 0 + e := func(i frontend.Variable) { + v := system.toVariable(i).(compiled.LinearExpression) + r = append(r, v) + s += len(v) + } + // e(i1) + // e(i2) + for i := 0; i < len(in); i++ { + e(in[i]) + } + return r, s +} + +// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to +// measure constraints, variables and coefficients creations through AddCounter +func (system *r1cs) Tag(name string) frontend.Tag { + _, file, line, _ := runtime.Caller(1) + + return frontend.Tag{ + Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), + VID: system.NbInternalVariables, + CID: len(system.Constraints), + } +} + +// AddCounter measures the number of constraints, variables and coefficients created between two tags +func (system *r1cs) AddCounter(from, to frontend.Tag) { + system.Counters = append(system.Counters, compiled.Counter{ + From: from.Name, + To: to.Name, + NbVariables: to.VID - from.VID, + NbConstraints: to.CID - from.CID, + CurveID: system.CurveID, + BackendID: backend.GROTH16, + }) +} + +// NewHint initializes internal variables whose value will be evaluated using +// the provided hint function at run time from the inputs. Inputs must be either +// variables or convertible to *big.Int. The function returns an error if the +// number of inputs is not compatible with f. +// +// The hint function is provided at the proof creation time and is not embedded +// into the circuit. From the backend point of view, the variable returned by +// the hint function is equivalent to the user-supplied witness, but its actual +// value is assigned by the solver, not the caller. +// +// No new constraints are added to the newly created wire and must be added +// manually in the circuit. Failing to do so leads to solver failure. +func (system *r1cs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + + if nbOutputs <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + hintInputs := make([]interface{}, len(inputs)) + + // ensure inputs are set and pack them in a []uint64 + for i, in := range inputs { + switch t := in.(type) { + case compiled.LinearExpression: + assertIsSet(t) + tmp := make(compiled.LinearExpression, len(t)) + copy(tmp, t) + hintInputs[i] = tmp + default: + hintInputs[i] = utils.FromInterface(t) + } + } + + // prepare wires + varIDs := make([]int, nbOutputs) + res := make([]frontend.Variable, len(varIDs)) + for i := range varIDs { + r := system.newInternalVariable() + _, vID, _ := r[0].Unpack() + varIDs[i] = vID + res[i] = r + } + + ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + for _, vID := range varIDs { + system.MHints[vID] = ch + } + + return res, nil +} + +// assertIsSet panics if the variable is unset +// this may happen if inside a Define we have +// var a variable +// cs.Mul(a, 1) +// since a was not in the circuit struct it is not a secret variable +func assertIsSet(l compiled.LinearExpression) { + // TODO PlonK scs doesn't have a similar check with compiled.Term == 0 + if len(l) == 0 { + // errNoValue triggered when trying to access a variable that was not allocated + errNoValue := errors.New("can't determine API input value") + panic(errNoValue) + } + +} diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go deleted file mode 100644 index c71d520e9d..0000000000 --- a/frontend/cs/r1cs/conversion.go +++ /dev/null @@ -1,250 +0,0 @@ -/* -Copyright © 2020 ConsenSys - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package r1cs - -import ( - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/schema" - bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" - bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" - bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" - bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" - bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" - bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" - "github.com/consensys/gnark/internal/backend/compiled" -) - -// Compile constructs a rank-1 constraint sytem -func (cs *r1CS) Compile() (frontend.CompiledConstraintSystem, error) { - - // wires = public wires | secret wires | internal wires - - // setting up the result - res := compiled.R1CS{ - CS: cs.CS, - Constraints: cs.Constraints, - } - res.NbPublicVariables = len(cs.Public) - res.NbSecretVariables = len(cs.Secret) - - // for Logs, DebugInfo and hints the only thing that will change - // is that ID of the wires will be offseted to take into account the final wire vector ordering - // that is: public wires | secret wires | internal wires - - // offset variable ID depeneding on visibility - shiftVID := func(oldID int, visibility schema.Visibility) int { - switch visibility { - case schema.Internal: - return oldID + res.NbPublicVariables + res.NbSecretVariables - case schema.Public: - return oldID - case schema.Secret: - return oldID + res.NbPublicVariables - } - return oldID - } - - // we just need to offset our ids, such that wires = [ public wires | secret wires | internal wires ] - offsetIDs := func(l compiled.LinearExpression) { - for j := 0; j < len(l); j++ { - _, vID, visibility := l[j].Unpack() - l[j].SetWireID(shiftVID(vID, visibility)) - } - } - - for i := 0; i < len(res.Constraints); i++ { - offsetIDs(res.Constraints[i].L.LinExp) - offsetIDs(res.Constraints[i].R.LinExp) - offsetIDs(res.Constraints[i].O.LinExp) - } - - // we need to offset the ids in the hints - shiftedMap := make(map[int]*compiled.Hint) - - // we need to offset the ids in the hints -HINTLOOP: - for _, hint := range cs.MHints { - ws := make([]int, len(hint.Wires)) - // we set for all outputs in shiftedMap. If one shifted output - // is in shiftedMap, then all are - for i, vID := range hint.Wires { - ws[i] = shiftVID(vID, schema.Internal) - if _, ok := shiftedMap[ws[i]]; i == 0 && ok { - continue HINTLOOP - } - } - inputs := make([]interface{}, len(hint.Inputs)) - copy(inputs, hint.Inputs) - for j := 0; j < len(inputs); j++ { - switch t := inputs[j].(type) { - case compiled.Variable: - tmp := make(compiled.LinearExpression, len(t.LinExp)) - copy(tmp, t.LinExp) - offsetIDs(tmp) - inputs[j] = tmp - case compiled.LinearExpression: - tmp := make(compiled.LinearExpression, len(t)) - copy(tmp, t) - offsetIDs(tmp) - inputs[j] = tmp - default: - inputs[j] = t - } - } - ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} - for _, vID := range ws { - shiftedMap[vID] = ch - } - } - res.MHints = shiftedMap - - // we need to offset the ids in Logs & DebugInfo - for i := 0; i < len(cs.Logs); i++ { - - for j := 0; j < len(res.Logs[i].ToResolve); j++ { - _, vID, visibility := res.Logs[i].ToResolve[j].Unpack() - res.Logs[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) - } - } - for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { - _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() - res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) - } - } - - // build levels - res.Levels = buildLevels(res) - - switch cs.CurveID { - case ecc.BLS12_377: - return bls12377r1cs.NewR1CS(res, cs.Coeffs), nil - case ecc.BLS12_381: - return bls12381r1cs.NewR1CS(res, cs.Coeffs), nil - case ecc.BN254: - return bn254r1cs.NewR1CS(res, cs.Coeffs), nil - case ecc.BW6_761: - return bw6761r1cs.NewR1CS(res, cs.Coeffs), nil - case ecc.BW6_633: - return bw6633r1cs.NewR1CS(res, cs.Coeffs), nil - case ecc.BLS24_315: - return bls24315r1cs.NewR1CS(res, cs.Coeffs), nil - default: - panic("not implemtented") - } -} - -func (cs *r1CS) SetSchema(s *schema.Schema) { - cs.Schema = s -} - -func buildLevels(ccs compiled.R1CS) [][]int { - - b := levelBuilder{ - mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire - nodeLevels: make([]int, len(ccs.Constraints)), // level of a node - mLevels: make(map[int]int), // level counts - ccs: ccs, - nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, - } - - // for each constraint, we're going to find its direct dependencies - // that is, wires (solved by previous constraints) on which it depends - // each of these dependencies is tagged with a level - // current constraint will be tagged with max(level) + 1 - for cID, c := range ccs.Constraints { - - b.nodeLevel = 0 - - b.processLE(c.L.LinExp, cID) - b.processLE(c.R.LinExp, cID) - b.processLE(c.O.LinExp, cID) - b.nodeLevels[cID] = b.nodeLevel - b.mLevels[b.nodeLevel]++ - - } - - levels := make([][]int, len(b.mLevels)) - for i := 0; i < len(levels); i++ { - // allocate memory - levels[i] = make([]int, 0, b.mLevels[i]) - } - - for n, l := range b.nodeLevels { - levels[l] = append(levels[l], n) - } - - return levels -} - -type levelBuilder struct { - ccs compiled.R1CS - nbInputs int - - mWireToNode map[int]int // at which node we resolved which wire - nodeLevels []int // level per node - mLevels map[int]int // number of constraint per level - - nodeLevel int // current level -} - -func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { - - for _, t := range l { - wID := t.WireID() - if wID < b.nbInputs { - // it's a input, we ignore it - continue - } - - // if we know a which constraint solves this wire, then it's a dependency - n, ok := b.mWireToNode[wID] - if ok { - if n != cID { // can happen with hints... - // we add a dependency, check if we need to increment our current level - if b.nodeLevels[n] >= b.nodeLevel { - b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it - } - } - continue - } - - // check if it's a hint and mark all the output wires - if h, ok := b.ccs.MHints[wID]; ok { - - for _, in := range h.Inputs { - switch t := in.(type) { - case compiled.Variable: - b.processLE(t.LinExp, cID) - case compiled.LinearExpression: - b.processLE(t, cID) - case compiled.Term: - b.processLE(compiled.LinearExpression{t}, cID) - } - } - - for _, hwid := range h.Wires { - b.mWireToNode[hwid] = cID - } - continue - } - - // mark this wire solved by current node - b.mWireToNode[wID] = cID - } -} diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go deleted file mode 100644 index eb22b890b8..0000000000 --- a/frontend/cs/r1cs/r1cs.go +++ /dev/null @@ -1,324 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package r1cs - -import ( - "errors" - "math/big" - "reflect" - "sort" - "strconv" - "strings" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs" - "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" -) - -func NewBuilder(curve ecc.ID) (frontend.Builder, error) { - return newR1CS(curve), nil -} - -type r1CS struct { - cs.ConstraintSystem - - Constraints []compiled.R1C -} - -// initialCapacity has quite some impact on frontend performance, especially on large circuits size -// we may want to add build tags to tune that -func newR1CS(curveID ecc.ID, initialCapacity ...int) *r1CS { - capacity := 0 - if len(initialCapacity) > 0 { - capacity = initialCapacity[0] - } - system := r1CS{ - ConstraintSystem: cs.ConstraintSystem{ - - CS: compiled.CS{ - MDebug: make(map[int]int), - MHints: make(map[int]*compiled.Hint), - }, - - Coeffs: make([]big.Int, 4), - CoeffsIDsLarge: make(map[string]int), - CoeffsIDsInt64: make(map[int64]int, 4), - }, - Constraints: make([]compiled.R1C, 0, capacity), - - // Counters: make([]Counter, 0), - } - - system.Coeffs[compiled.CoeffIdZero].SetInt64(0) - system.Coeffs[compiled.CoeffIdOne].SetInt64(1) - system.Coeffs[compiled.CoeffIdTwo].SetInt64(2) - system.Coeffs[compiled.CoeffIdMinusOne].SetInt64(-1) - - system.CoeffsIDsInt64[0] = compiled.CoeffIdZero - system.CoeffsIDsInt64[1] = compiled.CoeffIdOne - system.CoeffsIDsInt64[2] = compiled.CoeffIdTwo - system.CoeffsIDsInt64[-1] = compiled.CoeffIdMinusOne - - // system.public.variables = make([]Variable, 0) - // system.secret.variables = make([]Variable, 0) - // system.internal = make([]Variable, 0, capacity) - system.Public = make([]string, 1) - system.Secret = make([]string, 0) - - // by default the circuit is given a public wire equal to 1 - system.Public[0] = "one" - - system.CurveID = curveID - // system.BackendID = backendID - - return &system -} - -// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets -// the wire's id to the number of wires, and returns it -func (system *r1CS) newInternalVariable() compiled.Variable { - t := false - idx := system.NbInternalVariables - system.NbInternalVariables++ - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal)}, - IsBoolean: &t, - } -} - -// NewPublicVariable creates a new public Variable -func (system *r1CS) NewPublicVariable(name string) frontend.Variable { - t := false - idx := len(system.Public) - system.Public = append(system.Public, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Public)}, - IsBoolean: &t, - } - return res -} - -// NewSecretVariable creates a new secret Variable -func (system *r1CS) NewSecretVariable(name string) frontend.Variable { - t := false - idx := len(system.Secret) - system.Secret = append(system.Secret, name) - res := compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret)}, - IsBoolean: &t, - } - return res -} - -// func (v *variable) constantValue(system *R1CS) *big.Int { -func (system *r1CS) constantValue(v compiled.Variable) *big.Int { - // TODO this might be a good place to start hunting useless allocations. - // maybe through a big.Int pool. - if !v.IsConstant() { - panic("can't get big.Int value on a non-constant variable") - } - return new(big.Int).Set(&system.Coeffs[v.LinExp[0].CoeffID()]) -} - -func (system *r1CS) one() compiled.Variable { - t := false - return compiled.Variable{ - LinExp: compiled.LinearExpression{compiled.Pack(0, compiled.CoeffIdOne, schema.Public)}, - IsBoolean: &t, - } -} - -// reduces redundancy in linear expression -// It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset -// for each visibility, the Variables are sorted from lowest ID to highest ID -func (system *r1CS) reduce(l compiled.Variable) compiled.Variable { - // ensure our linear expression is sorted, by visibility and by Variable ID - if !sort.IsSorted(l.LinExp) { // may not help - sort.Sort(l.LinExp) - } - - mod := system.CurveID.Info().Fr.Modulus() - c := new(big.Int) - for i := 1; i < len(l.LinExp); i++ { - pcID, pvID, pVis := l.LinExp[i-1].Unpack() - ccID, cvID, cVis := l.LinExp[i].Unpack() - if pVis == cVis && pvID == cvID { - // we have redundancy - c.Add(&system.Coeffs[pcID], &system.Coeffs[ccID]) - c.Mod(c, mod) - l.LinExp[i-1].SetCoeffID(system.CoeffID(c)) - l.LinExp = append(l.LinExp[:i], l.LinExp[i+1:]...) - i-- - } - } - return l -} - -// newR1C clones the linear expression associated with the Variables (to avoid offseting the ID multiple time) -// and return a R1C -func newR1C(_l, _r, _o frontend.Variable) compiled.R1C { - l := _l.(compiled.Variable) - r := _r.(compiled.Variable) - o := _o.(compiled.Variable) - - // interestingly, this is key to groth16 performance. - // l * r == r * l == o - // but the "l" linear expression is going to end up in the A matrix - // the "r" linear expression is going to end up in the B matrix - // the less Variable we have appearing in the B matrix, the more likely groth16.Setup - // is going to produce infinity points in pk.G1.B and pk.G2.B, which will speed up proving time - if len(l.LinExp) > len(r.LinExp) { - l, r = r, l - } - - return compiled.R1C{L: l.Clone(), R: r.Clone(), O: o.Clone()} -} - -func (system *r1CS) addConstraint(r1c compiled.R1C, debugID ...int) { - system.Constraints = append(system.Constraints, r1c) - if len(debugID) > 0 { - system.MDebug[len(system.Constraints)-1] = debugID[0] - } -} - -// Term packs a Variable and a coeff in a Term and returns it. -// func (system *R1CSRefactor) setCoeff(v Variable, coeff *big.Int) Term { -func (system *r1CS) setCoeff(v compiled.Term, coeff *big.Int) compiled.Term { - _, vID, vVis := v.Unpack() - return compiled.Pack(vID, system.CoeffID(coeff), vVis) -} - -// markBoolean marks the Variable as boolean and return true -// if a constraint was added, false if the Variable was already -// constrained as a boolean -func (system *r1CS) markBoolean(v compiled.Variable) bool { - if *v.IsBoolean { - return false - } - *v.IsBoolean = true - return true -} - -// checkVariables perform post compilation checks on the Variables -// -// 1. checks that all user inputs are referenced in at least one constraint -// 2. checks that all hints are constrained -func (system *r1CS) CheckVariables() error { - - // TODO @gbotrel add unit test for that. - - cptSecret := len(system.Secret) - cptPublic := len(system.Public) - cptHints := len(system.MHints) - - secretConstrained := make([]bool, cptSecret) - publicConstrained := make([]bool, cptPublic) - // one wire does not need to be constrained - publicConstrained[0] = true - cptPublic-- - - mHintsConstrained := make(map[int]bool) - - // for each constraint, we check the linear expressions and mark our inputs / hints as constrained - processLinearExpression := func(l compiled.Variable) { - for _, t := range l.LinExp { - if t.CoeffID() == compiled.CoeffIdZero { - // ignore zero coefficient, as it does not constraint the Variable - // though, we may want to flag that IF the Variable doesn't appear else where - continue - } - visibility := t.VariableVisibility() - vID := t.WireID() - - switch visibility { - case schema.Public: - if vID != 0 && !publicConstrained[vID] { - publicConstrained[vID] = true - cptPublic-- - } - case schema.Secret: - if !secretConstrained[vID] { - secretConstrained[vID] = true - cptSecret-- - } - case schema.Internal: - if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { - mHintsConstrained[vID] = true - cptHints-- - } - } - } - } - for _, r1c := range system.Constraints { - processLinearExpression(r1c.L) - processLinearExpression(r1c.R) - processLinearExpression(r1c.O) - - if cptHints|cptSecret|cptPublic == 0 { - return nil // we can stop. - } - - } - - // something is a miss, we build the error string - var sbb strings.Builder - if cptSecret != 0 { - sbb.WriteString(strconv.Itoa(cptSecret)) - sbb.WriteString(" unconstrained secret input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { - if !secretConstrained[i] { - sbb.WriteString(system.Secret[i]) - sbb.WriteByte('\n') - cptSecret-- - } - } - sbb.WriteByte('\n') - } - - if cptPublic != 0 { - sbb.WriteString(strconv.Itoa(cptPublic)) - sbb.WriteString(" unconstrained public input(s):") - sbb.WriteByte('\n') - for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { - if !publicConstrained[i] { - sbb.WriteString(system.Public[i]) - sbb.WriteByte('\n') - cptPublic-- - } - } - sbb.WriteByte('\n') - } - - if cptHints != 0 { - sbb.WriteString(strconv.Itoa(cptHints)) - sbb.WriteString(" unconstrained hints") - sbb.WriteByte('\n') - // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some - // debugInfo to find where a hint was declared (and not constrained) - } - return errors.New(sbb.String()) -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index 573bb2a0d9..97c8a2f6d9 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -21,8 +21,9 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" ) func TestQuickSort(t *testing.T) { @@ -50,7 +51,7 @@ func TestQuickSort(t *testing.T) { func TestReduce(t *testing.T) { - cs := newR1CS(ecc.BN254) + cs := newCompiler(ecc.BN254, frontend.CompileConfig{}) x := cs.newInternalVariable() y := cs.newInternalVariable() z := cs.newInternalVariable() @@ -62,10 +63,10 @@ func TestReduce(t *testing.T) { e := cs.Mul(z, 2) f := cs.Mul(z, 2) - toTest := (cs.Add(a, b, c, d, e, f)).(compiled.Variable) + toTest := (cs.Add(a, b, c, d, e, f)).(compiled.LinearExpression) // check sizes - if len(toTest.LinExp) != 3 { + if len(toTest) != 3 { t.Fatal("Error reduce, duplicate variables not collapsed") } diff --git a/frontend/cs/plonk/api.go b/frontend/cs/scs/api.go similarity index 56% rename from frontend/cs/plonk/api.go rename to frontend/cs/scs/api.go index 117dc7744a..5a234cbded 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/scs/api.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plonk +package scs import ( "fmt" @@ -25,16 +25,14 @@ import ( "strconv" "strings" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" ) // Add returns res = i1+i2+...in -func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { zero := big.NewInt(0) vars, k := system.filterConstantSum(append([]frontend.Variable{i1, i2}, in...)) @@ -46,7 +44,7 @@ func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) return system.splitSum(vars[0], vars[1:]) } cl, _, _ := vars[0].Unpack() - kID := system.CoeffID(&k) + kID := system.st.CoeffID(&k) o := system.newInternalVariable() system.addPlonkConstraint(vars[0], system.zero(), o, cl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdMinusOne, kID) return system.splitSum(o, vars[1:]) @@ -54,7 +52,7 @@ func (system *sparseR1CS) Add(i1, i2 frontend.Variable, in ...frontend.Variable) } // neg returns -in -func (system *sparseR1CS) neg(in []frontend.Variable) []frontend.Variable { +func (system *scs) neg(in []frontend.Variable) []frontend.Variable { res := make([]frontend.Variable, len(in)) @@ -65,31 +63,31 @@ func (system *sparseR1CS) neg(in []frontend.Variable) []frontend.Variable { } // Sub returns res = i1 - i2 - ...in -func (system *sparseR1CS) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { r := system.neg(append([]frontend.Variable{i2}, in...)) return system.Add(i1, r[0], r[1:]...) } // Neg returns -i -func (system *sparseR1CS) Neg(i1 frontend.Variable) frontend.Variable { - if system.IsConstant(i1) { - k := system.ConstantValue(i1) - k.Neg(k) - return *k +func (system *scs) Neg(i1 frontend.Variable) frontend.Variable { + if n, ok := system.ConstantValue(i1); ok { + n.Neg(n) + // TODO shouldn't that go through variable conversion? + return *n } else { v := i1.(compiled.Term) c, _, _ := v.Unpack() var coef big.Int - coef.Set(&system.Coeffs[c]) + coef.Set(&system.st.Coeffs[c]) coef.Neg(&coef) - c = system.CoeffID(&coef) + c = system.st.CoeffID(&coef) v.SetCoeffID(c) return v } } // Mul returns res = i1 * i2 * ... in -func (system *sparseR1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { +func (system *scs) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars, k := system.filterConstantProd(append([]frontend.Variable{i1, i2}, in...)) if len(vars) == 0 { @@ -101,37 +99,38 @@ func (system *sparseR1CS) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) } // returns t*m -func (system *sparseR1CS) mulConstant(t compiled.Term, m *big.Int) compiled.Term { +func (system *scs) mulConstant(t compiled.Term, m *big.Int) compiled.Term { var coef big.Int cid, _, _ := t.Unpack() - coef.Set(&system.Coeffs[cid]) + coef.Set(&system.st.Coeffs[cid]) coef.Mul(m, &coef).Mod(&coef, system.CurveID.Info().Fr.Modulus()) - cid = system.CoeffID(&coef) + cid = system.st.CoeffID(&coef) t.SetCoeffID(cid) return t } // DivUnchecked returns i1 / i2 . if i1 == i2 == 0, returns 0 -func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + c1, i1Constant := system.ConstantValue(i1) + c2, i2Constant := system.ConstantValue(i2) - if system.IsConstant(i1) && system.IsConstant(i2) { - l := utils.FromInterface(i1) - r := utils.FromInterface(i2) + if i1Constant && i2Constant { + l := c1 + r := c2 q := system.CurveID.Info().Fr.Modulus() - return r.ModInverse(&r, q). - Mul(&l, &r). - Mod(&r, q) + return r.ModInverse(r, q). + Mul(l, r). + Mod(r, q) } - if system.IsConstant(i2) { - c := utils.FromInterface(i2) + if i2Constant { + c := c2 m := system.CurveID.Info().Fr.Modulus() - c.ModInverse(&c, m) - return system.mulConstant(i1.(compiled.Term), &c) + c.ModInverse(c, m) + return system.mulConstant(i1.(compiled.Term), c) } - if system.IsConstant(i1) { + if i1Constant { res := system.Inverse(i2) - m := utils.FromInterface(i1) - return system.mulConstant(res.(compiled.Term), &m) + return system.mulConstant(res.(compiled.Term), c1) } res := system.newInternalVariable() @@ -144,7 +143,7 @@ func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variab } // Div returns i1 / i2 -func (system *sparseR1CS) Div(i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) Div(i1, i2 frontend.Variable) frontend.Variable { // note that here we ensure that v2 can't be 0, but it costs us one extra constraint system.Inverse(i2) @@ -153,10 +152,9 @@ func (system *sparseR1CS) Div(i1, i2 frontend.Variable) frontend.Variable { } // Inverse returns res = 1 / i1 -func (system *sparseR1CS) Inverse(i1 frontend.Variable) frontend.Variable { - if system.IsConstant(i1) { - c := utils.FromInterface(i1) - c.ModInverse(&c, system.CurveID.Info().Fr.Modulus()) +func (system *scs) Inverse(i1 frontend.Variable) frontend.Variable { + if c, ok := system.ConstantValue(i1); ok { + c.ModInverse(c, system.CurveID.Info().Fr.Modulus()) return c } t := i1.(compiled.Term) @@ -175,7 +173,7 @@ func (system *sparseR1CS) Inverse(i1 frontend.Variable) frontend.Variable { // n default value is fr.Bits the number of bits needed to represent a field element // // The result in in little endian (first bit= lsb) -func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { +func (system *scs) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { // nbBits nbBits := system.BitLen() @@ -187,8 +185,7 @@ func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Va } // if a is a constant, work with the big int value. - if system.IsConstant(i1) { - c := utils.FromInterface(i1) + if c, ok := system.ConstantValue(i1); ok { b := make([]frontend.Variable, nbBits) for i := 0; i < len(b); i++ { b[i] = c.Bit(i) @@ -200,23 +197,23 @@ func (system *sparseR1CS) ToBinary(i1 frontend.Variable, n ...int) []frontend.Va return system.toBinary(a, nbBits, false) } -func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { +func (system *scs) toBinary(a compiled.Term, nbBits int, unsafe bool) []frontend.Variable { // allocate the resulting frontend.Variables and bit-constraint them - b := make([]frontend.Variable, nbBits) sb := make([]frontend.Variable, nbBits) var c big.Int c.SetUint64(1) + + bits, err := system.NewHint(hint.NBits, nbBits, a) + if err != nil { + panic(err) + } + for i := 0; i < nbBits; i++ { - res, err := system.NewHint(hint.IthBit, a, i) - if err != nil { - panic(err) - } - b[i] = res[0] - sb[i] = system.Mul(b[i], c) + sb[i] = system.Mul(bits[i], c) c.Lsh(&c, 1) if !unsafe { - system.AssertIsBoolean(b[i]) + system.AssertIsBoolean(bits[i]) } } @@ -224,7 +221,7 @@ func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []f // TODO we can save a constraint here var Σbi frontend.Variable if nbBits == 1 { - system.AssertIsEqual(sb[0], a) + Σbi = sb[0] } else if nbBits == 2 { Σbi = system.Add(sb[0], sb[1]) } else { @@ -233,12 +230,12 @@ func (system *sparseR1CS) toBinary(a compiled.Term, nbBits int, unsafe bool) []f system.AssertIsEqual(Σbi, a) // record the constraint Σ (2**i * b[i]) == a - return b + return bits } // FromBinary packs b, seen as a fr.Element in little endian -func (system *sparseR1CS) FromBinary(b ...frontend.Variable) frontend.Variable { +func (system *scs) FromBinary(b ...frontend.Variable) frontend.Variable { _b := make([]frontend.Variable, len(b)) var c big.Int c.SetUint64(1) @@ -257,24 +254,26 @@ func (system *sparseR1CS) FromBinary(b ...frontend.Variable) frontend.Variable { // Xor returns a ^ b // a and b must be 0 or 1 -func (system *sparseR1CS) Xor(a, b frontend.Variable) frontend.Variable { - if system.IsConstant(a) && system.IsConstant(b) { - _a := utils.FromInterface(a) - _b := utils.FromInterface(b) - _a.Xor(&_a, &_b) +func (system *scs) Xor(a, b frontend.Variable) frontend.Variable { + _a, aConstant := system.ConstantValue(a) + _b, bConstant := system.ConstantValue(b) + + if aConstant && bConstant { + _a.Xor(_a, _b) return _a } res := system.newInternalVariable() - if system.IsConstant(a) { + if aConstant { a, b = b, a + bConstant = aConstant + _b = _a } - if system.IsConstant(b) { + if bConstant { l := a.(compiled.Term) r := l - _b := utils.FromInterface(b) one := big.NewInt(1) - _b.Lsh(&_b, 1).Sub(&_b, one) - idl := system.CoeffID(&_b) + _b.Lsh(_b, 1).Sub(_b, one) + idl := system.st.CoeffID(_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } @@ -286,35 +285,32 @@ func (system *sparseR1CS) Xor(a, b frontend.Variable) frontend.Variable { // Or returns a | b // a and b must be 0 or 1 -func (system *sparseR1CS) Or(a, b frontend.Variable) frontend.Variable { - - var zero, one big.Int - one.SetUint64(1) +func (system *scs) Or(a, b frontend.Variable) frontend.Variable { + _a, aConstant := system.ConstantValue(a) + _b, bConstant := system.ConstantValue(b) - if system.IsConstant(a) && system.IsConstant(b) { - _a := utils.FromInterface(a) - _b := utils.FromInterface(b) - _a.Or(&_a, &_b) + if aConstant && bConstant { + _a.Or(_a, _b) return _a } res := system.newInternalVariable() - if system.IsConstant(a) { + if aConstant { a, b = b, a + _b = _a + bConstant = aConstant } - if system.IsConstant(b) { - _b := utils.FromInterface(b) - + if bConstant { l := a.(compiled.Term) r := l - if _b.Cmp(&one) != 0 && _b.Cmp(&zero) != 0 { + if !(_b.IsUint64() && (_b.Uint64() <= 1)) { panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) } system.AssertIsBoolean(a) one := big.NewInt(1) - _b.Sub(&_b, one) - idl := system.CoeffID(&_b) + _b.Sub(_b, one) + idl := system.st.CoeffID(_b) system.addPlonkConstraint(l, r, res, idl, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdOne, compiled.CoeffIdZero) return res } @@ -328,7 +324,7 @@ func (system *sparseR1CS) Or(a, b frontend.Variable) frontend.Variable { // Or returns a & b // a and b must be 0 or 1 -func (system *sparseR1CS) And(a, b frontend.Variable) frontend.Variable { +func (system *scs) And(a, b frontend.Variable) frontend.Variable { system.AssertIsBoolean(a) system.AssertIsBoolean(b) return system.Mul(a, b) @@ -338,16 +334,14 @@ func (system *sparseR1CS) And(a, b frontend.Variable) frontend.Variable { // Conditionals // Select if b is true, yields i1 else yields i2 -func (system *sparseR1CS) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { +func (system *scs) Select(b frontend.Variable, i1, i2 frontend.Variable) frontend.Variable { + _b, bConstant := system.ConstantValue(b) - if system.IsConstant(b) { - _b := utils.FromInterface(b) - var t big.Int - one := big.NewInt(1) - if _b.Cmp(&t) != 0 && _b.Cmp(one) != 0 { - panic("b should be a boolean") + if bConstant { + if !(_b.IsUint64() && (_b.Uint64() <= 1)) { + panic(fmt.Sprintf("%s should be 0 or 1", _b.String())) } - if _b.Cmp(&t) == 0 { + if _b.Uint64() == 0 { return i2 } return i1 @@ -362,7 +356,7 @@ func (system *sparseR1CS) Select(b frontend.Variable, i1, i2 frontend.Variable) // Lookup2 performs a 2-bit lookup between i1, i2, i3, i4 based on bits b0 // and b1. Returns i0 if b0=b1=0, i1 if b0=1 and b1=0, i2 if b0=0 and b1=1 // and i3 if b0=b1=1. -func (system *sparseR1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { +func (system *scs) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { // vars, _ := system.toVariables(b0, b1, i0, i1, i2, i3) // s0, s1 := vars[0], vars[1] @@ -398,12 +392,9 @@ func (system *sparseR1CS) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 front } // IsZero returns 1 if a is zero, 0 otherwise -func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { - - if system.IsConstant(i1) { - a := utils.FromInterface(i1) - var zero big.Int - if a.Cmp(&zero) != 0 { +func (system *scs) IsZero(i1 frontend.Variable) frontend.Variable { + if a, ok := system.ConstantValue(i1); ok { + if !(a.IsUint64() && a.Uint64() == 0) { panic("input should be zero") } return 1 @@ -413,7 +404,7 @@ func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { // a * m = 0 // constrain m to be 0 if a != 0 // _ = inverse(m + a) // constrain m to be 1 if a == 0 a := i1.(compiled.Term) - res, err := system.NewHint(hint.IsZero, a) + res, err := system.NewHint(hint.IsZero, 1, a) if err != nil { // the function errs only if the number of inputs is invalid. panic(err) @@ -427,7 +418,7 @@ func (system *sparseR1CS) IsZero(i1 frontend.Variable) frontend.Variable { } // Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1 bound -func (system *sparseR1CS) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { +func (system *scs) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { case compiled.Term: system.mustBeLessOrEqVar(v.(compiled.Term), b) @@ -96,7 +98,7 @@ func (system *sparseR1CS) AssertIsLessOrEqual(v frontend.Variable, bound fronten } } -func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { +func (system *scs) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term) { debug := system.AddDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -127,7 +129,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint + system.MarkBoolean(aBits[i].(compiled.Term)) // this does not create a constraint system.addPlonkConstraint( l.(compiled.Term), @@ -143,7 +145,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term } -func (system *sparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { +func (system *scs) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { nbBits := system.BitLen() diff --git a/frontend/cs/scs/compiler.go b/frontend/cs/scs/compiler.go new file mode 100644 index 0000000000..fc7e9882d8 --- /dev/null +++ b/frontend/cs/scs/compiler.go @@ -0,0 +1,665 @@ +/* +Copyright © 2021 ConsenSys Software Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scs + +import ( + "errors" + "fmt" + "math/big" + "path/filepath" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs" + "github.com/consensys/gnark/frontend/schema" + bls12377r1cs "github.com/consensys/gnark/internal/backend/bls12-377/cs" + bls12381r1cs "github.com/consensys/gnark/internal/backend/bls12-381/cs" + bls24315r1cs "github.com/consensys/gnark/internal/backend/bls24-315/cs" + bn254r1cs "github.com/consensys/gnark/internal/backend/bn254/cs" + bw6633r1cs "github.com/consensys/gnark/internal/backend/bw6-633/cs" + bw6761r1cs "github.com/consensys/gnark/internal/backend/bw6-761/cs" + "github.com/consensys/gnark/internal/utils" +) + +func NewCompiler(curve ecc.ID, config frontend.CompileConfig) (frontend.Builder, error) { + return newCompiler(curve, config), nil +} + +type scs struct { + compiled.ConstraintSystem + Constraints []compiled.SparseR1C + + st cs.CoeffTable + config frontend.CompileConfig + + // map for recording boolean constrained variables (to not constrain them twice) + mtBooleans map[int]struct{} +} + +// initialCapacity has quite some impact on frontend performance, especially on large circuits size +// we may want to add build tags to tune that +func newCompiler(curveID ecc.ID, config frontend.CompileConfig) *scs { + system := scs{ + ConstraintSystem: compiled.ConstraintSystem{ + + MDebug: make(map[int]int), + MHints: make(map[int]*compiled.Hint), + }, + mtBooleans: make(map[int]struct{}), + Constraints: make([]compiled.SparseR1C, 0, config.Capacity), + st: cs.NewCoeffTable(), + config: config, + } + + system.Public = make([]string, 0) + system.Secret = make([]string, 0) + + system.CurveID = curveID + + return &system +} + +// addPlonkConstraint creates a constraint of the for al+br+clr+k=0 +//func (system *SparseR1CS) addPlonkConstraint(l, r, o frontend.Variable, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { +func (system *scs) addPlonkConstraint(l, r, o compiled.Term, cidl, cidr, cidm1, cidm2, cido, k int, debugID ...int) { + + if len(debugID) > 0 { + system.MDebug[len(system.Constraints)] = debugID[0] + } + + l.SetCoeffID(cidl) + r.SetCoeffID(cidr) + o.SetCoeffID(cido) + + u := l + v := r + u.SetCoeffID(cidm1) + v.SetCoeffID(cidm2) + + //system.Constraints = append(system.Constraints, compiled.SparseR1C{L: _l, R: _r, O: _o, M: [2]compiled.Term{u, v}, K: k}) + system.Constraints = append(system.Constraints, compiled.SparseR1C{L: l, R: r, O: o, M: [2]compiled.Term{u, v}, K: k}) +} + +// newInternalVariable creates a new wire, appends it on the list of wires of the circuit, sets +// the wire's id to the number of wires, and returns it +func (system *scs) newInternalVariable() compiled.Term { + idx := system.NbInternalVariables + system.NbInternalVariables++ + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Internal) +} + +// AddPublicVariable creates a new Public Variable +func (system *scs) AddPublicVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddPublicVariable in circuit.Define()") + } + idx := len(system.Public) + system.Public = append(system.Public, name) + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Public) +} + +// AddSecretVariable creates a new Secret Variable +func (system *scs) AddSecretVariable(name string) frontend.Variable { + if system.Schema != nil { + panic("do not call AddSecretVariable in circuit.Define()") + } + idx := len(system.Secret) + system.Secret = append(system.Secret, name) + return compiled.Pack(idx, compiled.CoeffIdOne, schema.Secret) +} + +// reduces redundancy in linear expression +// It factorizes Variable that appears multiple times with != coeff Ids +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset +// for each visibility, the Variables are sorted from lowest ID to highest ID +func (system *scs) reduce(l compiled.LinearExpression) compiled.LinearExpression { + + // ensure our linear expression is sorted, by visibility and by Variable ID + sort.Sort(l) + + mod := system.CurveID.Info().Fr.Modulus() + c := new(big.Int) + for i := 1; i < len(l); i++ { + pcID, pvID, pVis := l[i-1].Unpack() + ccID, cvID, cVis := l[i].Unpack() + if pVis == cVis && pvID == cvID { + // we have redundancy + c.Add(&system.st.Coeffs[pcID], &system.st.Coeffs[ccID]) + c.Mod(c, mod) + l[i-1].SetCoeffID(system.st.CoeffID(c)) + l = append(l[:i], l[i+1:]...) + i-- + } + } + return l +} + +// to handle wires that don't exist (=coef 0) in a sparse constraint +func (system *scs) zero() compiled.Term { + var a compiled.Term + return a +} + +// IsBoolean returns true if given variable was marked as boolean in the compiler (see MarkBoolean) +// Use with care; variable may not have been **constrained** to be boolean +// This returns true if the v is a constant and v == 0 || v == 1. +func (system *scs) IsBoolean(v frontend.Variable) bool { + if b, ok := system.ConstantValue(v); ok { + return b.IsUint64() && b.Uint64() <= 1 + } + _, ok := system.mtBooleans[int(v.(compiled.Term))] + return ok +} + +// MarkBoolean sets (but do not constraint!) v to be boolean +// This is useful in scenarios where a variable is known to be boolean through a constraint +// that is not api.AssertIsBoolean. If v is a constant, this is a no-op. +func (system *scs) MarkBoolean(v frontend.Variable) { + if b, ok := system.ConstantValue(v); ok { + if !(b.IsUint64() && b.Uint64() <= 1) { + panic("MarkBoolean called a non-boolean constant") + } + } + system.mtBooleans[int(v.(compiled.Term))] = struct{}{} +} + +// checkVariables perform post compilation checks on the Variables +// +// 1. checks that all user inputs are referenced in at least one constraint +// 2. checks that all hints are constrained +func (system *scs) checkVariables() error { + + // TODO @gbotrel add unit test for that. + + cptSecret := len(system.Secret) + cptPublic := len(system.Public) + cptHints := len(system.MHints) + + // compared to R1CS, we may have a circuit which does not have any inputs + // (R1CS always has a constant ONE wire). Check the edge case and omit any + // processing if so. + if cptSecret+cptPublic+cptHints == 0 { + return nil + } + + secretConstrained := make([]bool, cptSecret) + publicConstrained := make([]bool, cptPublic) + + mHintsConstrained := make(map[int]bool) + + // for each constraint, we check the terms and mark our inputs / hints as constrained + processTerm := func(t compiled.Term) { + + // L and M[0] handles the same wire but with a different coeff + visibility := t.VariableVisibility() + vID := t.WireID() + if t.CoeffID() != compiled.CoeffIdZero { + switch visibility { + case schema.Public: + if !publicConstrained[vID] { + publicConstrained[vID] = true + cptPublic-- + } + case schema.Secret: + if !secretConstrained[vID] { + secretConstrained[vID] = true + cptSecret-- + } + case schema.Internal: + if _, ok := system.MHints[vID]; !mHintsConstrained[vID] && ok { + mHintsConstrained[vID] = true + cptHints-- + } + } + } + + } + for _, c := range system.Constraints { + processTerm(c.L) + processTerm(c.R) + processTerm(c.M[0]) + processTerm(c.M[1]) + processTerm(c.O) + if cptHints|cptSecret|cptPublic == 0 { + return nil // we can stop. + } + + } + + // something is a miss, we build the error string + var sbb strings.Builder + if cptSecret != 0 { + sbb.WriteString(strconv.Itoa(cptSecret)) + sbb.WriteString(" unconstrained secret input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(secretConstrained) && cptSecret != 0; i++ { + if !secretConstrained[i] { + sbb.WriteString(system.Secret[i]) + sbb.WriteByte('\n') + cptSecret-- + } + } + sbb.WriteByte('\n') + } + + if cptPublic != 0 { + sbb.WriteString(strconv.Itoa(cptPublic)) + sbb.WriteString(" unconstrained public input(s):") + sbb.WriteByte('\n') + for i := 0; i < len(publicConstrained) && cptPublic != 0; i++ { + if !publicConstrained[i] { + sbb.WriteString(system.Public[i]) + sbb.WriteByte('\n') + cptPublic-- + } + } + sbb.WriteByte('\n') + } + + if cptHints != 0 { + sbb.WriteString(strconv.Itoa(cptHints)) + sbb.WriteString(" unconstrained hints") + sbb.WriteByte('\n') + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some + // debugInfo to find where a hint was declared (and not constrained) + } + return errors.New(sbb.String()) +} + +var tVariable reflect.Type + +func init() { + tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() +} + +func (cs *scs) Compile() (frontend.CompiledConstraintSystem, error) { + + // ensure all inputs and hints are constrained + if !cs.config.IgnoreUnconstrainedInputs { + if err := cs.checkVariables(); err != nil { + return nil, err + } + } + + res := compiled.SparseR1CS{ + ConstraintSystem: cs.ConstraintSystem, + Constraints: cs.Constraints, + } + res.NbPublicVariables = len(cs.Public) + res.NbSecretVariables = len(cs.Secret) + + // Logs, DebugInfo and hints are copied, the only thing that will change + // is that ID of the wires will be offseted to take into account the final wire vector ordering + // that is: public wires | secret wires | internal wires + + // shift variable ID + // we want publicWires | privateWires | internalWires + shiftVID := func(oldID int, visibility schema.Visibility) int { + switch visibility { + case schema.Internal: + return oldID + res.NbPublicVariables + res.NbSecretVariables + case schema.Public: + return oldID + case schema.Secret: + return oldID + res.NbPublicVariables + default: + return oldID + } + } + + offsetTermID := func(t *compiled.Term) { + _, VID, visibility := t.Unpack() + t.SetWireID(shiftVID(VID, visibility)) + } + + // offset the IDs of all constraints so that the variables are + // numbered like this: [publicVariables | secretVariables | internalVariables ] + for i := 0; i < len(res.Constraints); i++ { + r1c := &res.Constraints[i] + offsetTermID(&r1c.L) + offsetTermID(&r1c.R) + offsetTermID(&r1c.O) + offsetTermID(&r1c.M[0]) + offsetTermID(&r1c.M[1]) + } + + // we need to offset the ids in Logs & DebugInfo + for i := 0; i < len(cs.Logs); i++ { + for j := 0; j < len(res.Logs[i].ToResolve); j++ { + offsetTermID(&res.Logs[i].ToResolve[j]) + } + } + for i := 0; i < len(cs.DebugInfo); i++ { + for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { + offsetTermID(&res.DebugInfo[i].ToResolve[j]) + } + } + + // we need to offset the ids in the hints + shiftedMap := make(map[int]*compiled.Hint) +HINTLOOP: + for _, hint := range cs.MHints { + ws := make([]int, len(hint.Wires)) + // we set for all outputs in shiftedMap. If one shifted output + // is in shiftedMap, then all are + for i, vID := range hint.Wires { + ws[i] = shiftVID(vID, schema.Internal) + if _, ok := shiftedMap[ws[i]]; i == 0 && ok { + continue HINTLOOP + } + } + inputs := make([]interface{}, len(hint.Inputs)) + copy(inputs, hint.Inputs) + for j := 0; j < len(inputs); j++ { + switch t := inputs[j].(type) { + case compiled.Term: + offsetTermID(&t) + inputs[j] = t // TODO check if we can remove it + default: + inputs[j] = t + } + } + ch := &compiled.Hint{ID: hint.ID, Inputs: inputs, Wires: ws} + for _, vID := range ws { + shiftedMap[vID] = ch + } + } + res.MHints = shiftedMap + + // build levels + res.Levels = buildLevels(res) + + switch cs.CurveID { + case ecc.BLS12_377: + return bls12377r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BLS12_381: + return bls12381r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BN254: + return bn254r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_761: + return bw6761r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BLS24_315: + return bls24315r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + case ecc.BW6_633: + return bw6633r1cs.NewSparseR1CS(res, cs.st.Coeffs), nil + default: + panic("unknown curveID") + } + +} + +func (cs *scs) SetSchema(s *schema.Schema) { + if cs.Schema != nil { + panic("SetSchema called multiple times") + } + cs.Schema = s +} + +func buildLevels(ccs compiled.SparseR1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processTerm(c.L, cID) + b.processTerm(c.R, cID) + b.processTerm(c.O, cID) + + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.SparseR1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processTerm(t compiled.Term, cID int) { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + return + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + return + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.LinearExpression: + for _, tt := range t { + b.processTerm(tt, cID) + } + case compiled.Term: + b.processTerm(t, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + + return + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + +} + +// ConstantValue returns the big.Int value of v. It +// panics if v.IsConstant() == false +func (system *scs) ConstantValue(v frontend.Variable) (*big.Int, bool) { + switch t := v.(type) { + case compiled.Term: + return nil, false + default: + res := utils.FromInterface(t) + return &res, true + } +} + +func (system *scs) Backend() backend.ID { + return backend.PLONK +} + +// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to +// measure constraints, variables and coefficients creations through AddCounter +func (system *scs) Tag(name string) frontend.Tag { + _, file, line, _ := runtime.Caller(1) + + return frontend.Tag{ + Name: fmt.Sprintf("%s[%s:%d]", name, filepath.Base(file), line), + VID: system.NbInternalVariables, + CID: len(system.Constraints), + } +} + +// AddCounter measures the number of constraints, variables and coefficients created between two tags +// note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions +// are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time +func (system *scs) AddCounter(from, to frontend.Tag) { + system.Counters = append(system.Counters, compiled.Counter{ + From: from.Name, + To: to.Name, + NbVariables: to.VID - from.VID, + NbConstraints: to.CID - from.CID, + CurveID: system.CurveID, + BackendID: backend.PLONK, + }) +} + +// NewHint initializes internal variables whose value will be evaluated using +// the provided hint function at run time from the inputs. Inputs must be either +// variables or convertible to *big.Int. The function returns an error if the +// number of inputs is not compatible with f. +// +// The hint function is provided at the proof creation time and is not embedded +// into the circuit. From the backend point of view, the variable returned by +// the hint function is equivalent to the user-supplied witness, but its actual +// value is assigned by the solver, not the caller. +// +// No new constraints are added to the newly created wire and must be added +// manually in the circuit. Failing to do so leads to solver failure. +func (system *scs) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + + if nbOutputs <= 0 { + return nil, fmt.Errorf("hint function must return at least one output") + } + + hintInputs := make([]interface{}, len(inputs)) + + // ensure inputs are set and pack them in a []uint64 + for i, in := range inputs { + switch t := in.(type) { + case compiled.Term: + hintInputs[i] = t + default: + hintInputs[i] = utils.FromInterface(in) + } + } + + // prepare wires + varIDs := make([]int, nbOutputs) + res := make([]frontend.Variable, len(varIDs)) + for i := range varIDs { + r := system.newInternalVariable() + _, vID, _ := r.Unpack() + varIDs[i] = vID + res[i] = r + } + + ch := &compiled.Hint{ID: f.UUID(), Inputs: hintInputs, Wires: varIDs} + for _, vID := range varIDs { + system.MHints[vID] = ch + } + + return res, nil +} + +// returns in split into a slice of compiledTerm and the sum of all constants in in as a bigInt +func (system *scs) filterConstantSum(in []frontend.Variable) (compiled.LinearExpression, big.Int) { + res := make(compiled.LinearExpression, 0, len(in)) + var b big.Int + for i := 0; i < len(in); i++ { + switch t := in[i].(type) { + case compiled.Term: + res = append(res, t) + default: + n := utils.FromInterface(t) + b.Add(&b, &n) + } + } + return res, b +} + +// returns in split into a slice of compiledTerm and the product of all constants in in as a bigInt +func (system *scs) filterConstantProd(in []frontend.Variable) (compiled.LinearExpression, big.Int) { + res := make(compiled.LinearExpression, 0, len(in)) + var b big.Int + b.SetInt64(1) + for i := 0; i < len(in); i++ { + switch t := in[i].(type) { + case compiled.Term: + res = append(res, t) + default: + n := utils.FromInterface(t) + b.Mul(&b, &n).Mod(&b, system.CurveID.Info().Fr.Modulus()) + } + } + return res, b +} + +func (system *scs) splitSum(acc compiled.Term, r compiled.LinearExpression) compiled.Term { + + // floor case + if len(r) == 0 { + return acc + } + + cl, _, _ := acc.Unpack() + cr, _, _ := r[0].Unpack() + o := system.newInternalVariable() + system.addPlonkConstraint(acc, r[0], o, cl, cr, compiled.CoeffIdZero, compiled.CoeffIdZero, compiled.CoeffIdMinusOne, compiled.CoeffIdZero) + return system.splitSum(o, r[1:]) +} + +func (system *scs) splitProd(acc compiled.Term, r compiled.LinearExpression) compiled.Term { + + // floor case + if len(r) == 0 { + return acc + } + + cl, _, _ := acc.Unpack() + cr, _, _ := r[0].Unpack() + o := system.newInternalVariable() + system.addPlonkConstraint(acc, r[0], o, compiled.CoeffIdZero, compiled.CoeffIdZero, cl, cr, compiled.CoeffIdMinusOne, compiled.CoeffIdZero) + return system.splitProd(o, r[1:]) +} diff --git a/frontend/registry.go b/frontend/registry.go deleted file mode 100644 index 6c2b3b7642..0000000000 --- a/frontend/registry.go +++ /dev/null @@ -1,36 +0,0 @@ -package frontend - -import ( - "fmt" - "sync" - - "github.com/consensys/gnark/backend" -) - -var ( - backends = make(map[backend.ID]NewBuilder) - backendsM sync.RWMutex -) - -// RegisterDefaultBuilder registers a frontend f for a backend b. This registration -// ensures that a correct frontend system is chosen for a specific backend when -// compiling a circuit. The method does not check that the compiler for that -// frontend is already registered and the compiler is looked up during compile -// time. It is an error to double-assign a frontend to a single backend and the -// mehod panics. -// -// /!\ This is highly experimental and may change in upcoming releases /!\ -func RegisterDefaultBuilder(b backend.ID, builder NewBuilder) { - if b == backend.UNKNOWN { - panic("can not assign builder to unknown backend") - } - - // a frontend may be assigned before a compiler to that frontend is - // registered. we perform frontend compiler lookup during compilation. - backendsM.Lock() - defer backendsM.Unlock() - if _, ok := backends[b]; ok { - panic(fmt.Sprintf("double frontend registration for backend '%s'", b)) - } - backends[b] = builder -} diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 1313d715e4..85b3b3694a 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 8c5d2c087b..c1d6a0e887 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" diff --git a/internal/backend/bls12-377/cs/r1cs_test.go b/internal/backend/bls12-377/cs/r1cs_test.go index e63b16c7d6..8af814bc99 100644 --- a/internal/backend/bls12-377/cs/r1cs_test.go +++ b/internal/backend/bls12-377/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_377, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 6911e72ab7..b3d63f66d9 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls12-377/groth16/groth16_test.go b/internal/backend/bls12-377/groth16/groth16_test.go index a467692558..ebcfb09f76 100644 --- a/internal/backend/bls12-377/groth16/groth16_test.go +++ b/internal/backend/bls12-377/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 21fe78713c..95112cddee 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bls12-377/plonk/plonk_test.go b/internal/backend/bls12-377/plonk/plonk_test.go index 8c9affeab7..2076347fd5 100644 --- a/internal/backend/bls12-377/plonk/plonk_test.go +++ b/internal/backend/bls12-377/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 0a32652ad9..4beef74f02 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 106e6eb0eb..a3ce43cd74 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" diff --git a/internal/backend/bls12-381/cs/r1cs_test.go b/internal/backend/bls12-381/cs/r1cs_test.go index a96a65d21c..71ba350ad0 100644 --- a/internal/backend/bls12-381/cs/r1cs_test.go +++ b/internal/backend/bls12-381/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS12_381, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 9d630c8153..d577ae48d9 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls12-381/groth16/groth16_test.go b/internal/backend/bls12-381/groth16/groth16_test.go index 721e57e6bf..1c229f9ac8 100644 --- a/internal/backend/bls12-381/groth16/groth16_test.go +++ b/internal/backend/bls12-381/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index d481f8ad0d..b76aa9c87f 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bls12-381/plonk/plonk_test.go b/internal/backend/bls12-381/plonk/plonk_test.go index 5332ba3016..712add07b1 100644 --- a/internal/backend/bls12-381/plonk/plonk_test.go +++ b/internal/backend/bls12-381/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 964acf1f50..a66744f5f0 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 9ecd27fb26..8451a8e5cf 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" diff --git a/internal/backend/bls24-315/cs/r1cs_test.go b/internal/backend/bls24-315/cs/r1cs_test.go index 931504b2f7..1fb373ed46 100644 --- a/internal/backend/bls24-315/cs/r1cs_test.go +++ b/internal/backend/bls24-315/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BLS24_315, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index e215ac343a..8e4d24cd8a 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bls24-315/groth16/groth16_test.go b/internal/backend/bls24-315/groth16/groth16_test.go index af5f319940..ed67ed605d 100644 --- a/internal/backend/bls24-315/groth16/groth16_test.go +++ b/internal/backend/bls24-315/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index 596e7786fc..ad5252165e 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bls24-315/plonk/plonk_test.go b/internal/backend/bls24-315/plonk/plonk_test.go index 71378ce85f..dcc755d9d4 100644 --- a/internal/backend/bls24-315/plonk/plonk_test.go +++ b/internal/backend/bls24-315/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index c44696bae1..4db2916282 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 8eab1f901f..0fa0d1523a 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bn254/fr" diff --git a/internal/backend/bn254/cs/r1cs_test.go b/internal/backend/bn254/cs/r1cs_test.go index c2eb0a7663..7c9806da35 100644 --- a/internal/backend/bn254/cs/r1cs_test.go +++ b/internal/backend/bn254/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BN254, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 46cb2eb6af..1d8d337c0c 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bn254/groth16/groth16_test.go b/internal/backend/bn254/groth16/groth16_test.go index 59667f5ce2..1fc0674145 100644 --- a/internal/backend/bn254/groth16/groth16_test.go +++ b/internal/backend/bn254/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 334461b25b..2ed4f2e6ec 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bn254/plonk/plonk_test.go b/internal/backend/bn254/plonk/plonk_test.go index 0fc5ff68ab..5860b7b3a7 100644 --- a/internal/backend/bn254/plonk/plonk_test.go +++ b/internal/backend/bn254/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 7101b1a5b5..ea70db009c 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index 7cc2ccfad0..196842fa42 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" diff --git a/internal/backend/bw6-633/cs/r1cs_test.go b/internal/backend/bw6-633/cs/r1cs_test.go index 5900e2b650..5287373c76 100644 --- a/internal/backend/bw6-633/cs/r1cs_test.go +++ b/internal/backend/bw6-633/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -37,7 +36,7 @@ func TestSerialization(t *testing.T) { t.Run(name, func(t *testing.T) { tc := circuits.Circuits[name] - r1cs1, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -46,7 +45,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -135,7 +134,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_633, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index 4b611e7f07..06247264c2 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bw6-633/groth16/groth16_test.go b/internal/backend/bw6-633/groth16/groth16_test.go index deb017f777..913eefffec 100644 --- a/internal/backend/bw6-633/groth16/groth16_test.go +++ b/internal/backend/bw6-633/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go index b26489abbc..27e2db3700 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/internal/backend/bw6-633/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bw6-633/plonk/plonk_test.go b/internal/backend/bw6-633/plonk/plonk_test.go index fa8b550794..b534add524 100644 --- a/internal/backend/bw6-633/plonk/plonk_test.go +++ b/internal/backend/bw6-633/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 097d53581c..bb29df9a55 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -28,8 +28,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" @@ -302,15 +302,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -387,11 +387,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 8cc16bcf4c..68c01f59f3 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -30,8 +30,8 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" diff --git a/internal/backend/bw6-761/cs/r1cs_test.go b/internal/backend/bw6-761/cs/r1cs_test.go index 94eec2486c..b17a3f0f16 100644 --- a/internal/backend/bw6-761/cs/r1cs_test.go +++ b/internal/backend/bw6-761/cs/r1cs_test.go @@ -19,7 +19,6 @@ package cs_test import ( "bytes" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -41,7 +40,7 @@ func TestSerialization(t *testing.T) { return } - r1cs1, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -50,7 +49,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -139,7 +138,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.BW6_761, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index fb1e1a19bd..d0b6dcf512 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -24,8 +24,8 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" - "github.com/consensys/gnark/internal/backend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -148,7 +148,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -170,9 +170,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/backend/bw6-761/groth16/groth16_test.go b/internal/backend/bw6-761/groth16/groth16_test.go index 6951cd4a76..2077099253 100644 --- a/internal/backend/bw6-761/groth16/groth16_test.go +++ b/internal/backend/bw6-761/groth16/groth16_test.go @@ -58,7 +58,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID, r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 0e100d3319..1f8d2d84ad 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -25,7 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -336,13 +336,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -491,10 +491,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/backend/bw6-761/plonk/plonk_test.go b/internal/backend/bw6-761/plonk/plonk_test.go index d4a25ddaf0..7e9b4fafd9 100644 --- a/internal/backend/bw6-761/plonk/plonk_test.go +++ b/internal/backend/bw6-761/plonk/plonk_test.go @@ -36,7 +36,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/frontend/cs/scs" ) //--------------------// @@ -62,7 +62,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/backend/circuits/hint.go b/internal/backend/circuits/hint.go index d6eec90424..3b9e99294e 100644 --- a/internal/backend/circuits/hint.go +++ b/internal/backend/circuits/hint.go @@ -14,7 +14,7 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(mulBy7, circuit.A) + res, err := api.Compiler().NewHint(mulBy7, 1, circuit.A) if err != nil { return fmt.Errorf("mulBy7: %w", err) } @@ -23,7 +23,7 @@ func (circuit *hintCircuit) Define(api frontend.API) error { api.AssertIsEqual(a7, _a7) api.AssertIsEqual(a7, circuit.B) - res, err = api.NewHint(make3) + res, err = api.Compiler().NewHint(make3, 1) if err != nil { return fmt.Errorf("make3: %w", err) } @@ -39,7 +39,7 @@ type vectorDoubleCircuit struct { } func (c *vectorDoubleCircuit) Define(api frontend.API) error { - res, err := api.NewHint(dvHint, c.A...) + res, err := api.Compiler().NewHint(dvHint, len(c.B), c.A...) if err != nil { return fmt.Errorf("double newhint: %w", err) } @@ -101,12 +101,12 @@ func init() { var mulBy7 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].Mul(inputs[0], big.NewInt(7)).Mod(result[0], curveID.Info().Fr.Modulus()) return nil -}, 1, 1) +}) var make3 = hint.NewStaticHint(func(curveID ecc.ID, inputs []*big.Int, result []*big.Int) error { result[0].SetUint64(3) return nil -}, 0, 1) +}) var dvHint = &doubleVector{} diff --git a/internal/backend/compiled/html.go b/internal/backend/compiled/html.go deleted file mode 100644 index b1335c396b..0000000000 --- a/internal/backend/compiled/html.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -const R1CSTemplate = ` - - - - - - - R1CS - - - - - - -
-

R1CS

- {{ $nbHints := len .MHints }} - {{ $nbConstraints := len .Constraints}} - {{.NbInternalVariables}} internal (includes {{$nbHints}} hints)
- {{.NbPublicVariables}} public
- {{.NbSecretVariables}} secret
- {{$nbConstraints}} constraints
-

L * R == O

-

-

-
- - - - - - - - - - - {{- range $i, $c := .Constraints}} - - - - - - - {{- end }} - -
#LRO
{{$i}} {{ toHTML $c.L $.Coefficients $.MHints}} {{ toHTML $c.R $.Coefficients $.MHints}} {{ toHTML $c.O $.Coefficients $.MHints}}
- - -` - -const SparseR1CSTemplate = ` - - - - - - - SparseR1CS - - - - - - -
-

SparseR1CS

- {{ $nbHints := len .MHints }} - {{ $nbConstraints := len .Constraints}} - - {{.NbInternalVariables}} internal (includes {{$nbHints}} hints)
- {{.NbPublicVariables}} public
- {{.NbSecretVariables}} secret
- {{$nbConstraints}} constraints
-

L + R + M0*M1 + O + k == 0

-

all variable id are offseted by 1 to match R1CS

-
- - - - - - - - - - - - - - - {{- range $i, $c := .Constraints}} - - - - - - - - - - {{- end }} - -
#LRM0M1Ok
{{$i}} {{ toHTML $c.L $.Coefficients $.MHints}} {{ toHTML $c.R $.Coefficients $.MHints}} {{ toHTML (index $c.M 0) $.Coefficients $.MHints}} {{ toHTML (index $c.M 1) $.Coefficients $.MHints}} {{ toHTML $c.O $.Coefficients $.MHints}} {{ toHTMLCoeff $c.K $.Coefficients }}
- - -` diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go deleted file mode 100644 index 8831f6da01..0000000000 --- a/internal/backend/compiled/r1cs.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -// R1CS decsribes a set of R1C constraint -type R1CS struct { - CS - Constraints []R1C -} - -// GetNbConstraints returns the number of constraints -func (r1cs *R1CS) GetNbConstraints() int { - return len(r1cs.Constraints) -} diff --git a/internal/backend/compiled/r1cs_sparse.go b/internal/backend/compiled/r1cs_sparse.go deleted file mode 100644 index bcd59d5994..0000000000 --- a/internal/backend/compiled/r1cs_sparse.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2020 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -// R1CS decsribes a set of SparseR1C constraint -type SparseR1CS struct { - CS - Constraints []SparseR1C -} - -// GetNbConstraints returns the number of constraints -func (cs *SparseR1CS) GetNbConstraints() int { - return len(cs.Constraints) -} diff --git a/internal/backend/compiled/variable.go b/internal/backend/compiled/variable.go deleted file mode 100644 index 3c6d499055..0000000000 --- a/internal/backend/compiled/variable.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2021 ConsenSys AG -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package compiled - -import ( - "errors" - "math/big" - "strings" - - "github.com/consensys/gnark/frontend/schema" -) - -// errNoValue triggered when trying to access a variable that was not allocated -var errNoValue = errors.New("can't determine API input value") - -type LinearExpression []Term - -// Variable represent a linear expression of wires -type Variable struct { - LinExp LinearExpression - IsBoolean *bool -} - -// Clone returns a copy of the underlying slice -func (v Variable) Clone() Variable { - var res Variable - res.IsBoolean = v.IsBoolean - res.LinExp = make([]Term, len(v.LinExp)) - copy(res.LinExp, v.LinExp) - return res -} - -// Len return the lenght of the Variable (implements Sort interface) -func (v LinearExpression) Len() int { - return len(v) -} - -// Equals returns true if both SORTED expressions are the same -// -// pre conditions: l and o are sorted -func (v LinearExpression) Equal(o LinearExpression) bool { - if len(v) != len(o) { - return false - } - if (v == nil) != (o == nil) { - return false - } - for i := 0; i < len(v); i++ { - if v[i] != o[i] { - return false - } - } - return true -} - -// Swap swaps terms in the Variable (implements Sort interface) -func (v LinearExpression) Swap(i, j int) { - v[i], v[j] = v[j], v[i] -} - -// Less returns true if variableID for term at i is less than variableID for term at j (implements Sort interface) -func (v LinearExpression) Less(i, j int) bool { - _, iID, iVis := v[i].Unpack() - _, jID, jVis := v[j].Unpack() - if iVis == jVis { - return iID < jID - } - return iVis > jVis -} - -func (v Variable) string(sbb *strings.Builder, coeffs []big.Int) { - for i := 0; i < len(v.LinExp); i++ { - v.LinExp[i].string(sbb, coeffs) - if i+1 < len(v.LinExp) { - sbb.WriteString(" + ") - } - } -} - -// assertIsSet panics if the variable is unset -// this may happen if inside a Define we have -// var a variable -// cs.Mul(a, 1) -// since a was not in the circuit struct it is not a secret variable -func (v Variable) AssertIsSet() { - - if len(v.LinExp) == 0 { - panic(errNoValue) - } - -} - -// isConstant returns true if the variable is ONE_WIRE * coeff -func (v *Variable) IsConstant() bool { - if len(v.LinExp) != 1 { - return false - } - _, vID, visibility := v.LinExp[0].Unpack() - return vID == 0 && visibility == schema.Public -} diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index ee3b99f846..09b1518a13 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -9,7 +9,7 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -292,15 +292,15 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.El return nil } - if err := processLExp(r.L.LinExp, a, 1); err != nil { + if err := processLExp(r.L, a, 1); err != nil { return err } - if err := processLExp(r.R.LinExp, b, 2); err != nil { + if err := processLExp(r.R, b, 2); err != nil { return err } - if err := processLExp(r.O.LinExp, c, 3); err != nil { + if err := processLExp(r.O, c, 3); err != nil { return err } @@ -379,11 +379,11 @@ func (cs *R1CS) GetConstraints() [][]string { return r } -func (cs *R1CS) vtoString(l compiled.Variable) string { +func (cs *R1CS) vtoString(l compiled.LinearExpression) string { var sbb strings.Builder - for i := 0; i < len(l.LinExp); i++ { - cs.termToString(l.LinExp[i], &sbb) - if i+1 < len(l.LinExp) { + for i := 0; i < len(l); i++ { + cs.termToString(l[i], &sbb) + if i+1 < len(l) { sbb.WriteString(" + ") } } diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index dbc9915d2c..0a220f3a4d 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -11,7 +11,7 @@ import ( "math" "github.com/consensys/gnark/internal/backend/ioutils" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark/backend/witness" diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index dba02f7cb6..98cca89b1a 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -6,7 +6,7 @@ import ( "sync/atomic" "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/frontend/schema" @@ -129,7 +129,7 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) - nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) + nbOutputs := len(h.Wires) // m := len(s.tmpHintsIO) // if m < (nbInputs + nbOutputs) { // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) @@ -151,9 +151,6 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { for i := 0; i < nbInputs; i++ { switch t := h.Inputs[i].(type) { - case compiled.Variable: - v := s.computeLinearExpression(t.LinExp) - v.ToBigIntRegular(inputs[i]) case compiled.LinearExpression: v := s.computeLinearExpression(t) v.ToBigIntRegular(inputs[i]) diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index 852a1583fe..cb399f03e5 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -3,7 +3,6 @@ import ( "bytes" "testing" "reflect" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/internal/backend/circuits" @@ -25,7 +24,7 @@ func TestSerialization(t *testing.T) { } {{end}} - r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs1, err := frontend.Compile(ecc.{{ .CurveID }}, r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -34,7 +33,7 @@ func TestSerialization(t *testing.T) { } // copmpile a second time to ensure determinism - r1cs2, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, tc.Circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs2, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewCompiler, tc.Circuit) if err != nil { t.Fatal(err) } @@ -124,7 +123,7 @@ func (circuit *circuit) Define(api frontend.API) error { func BenchmarkSolve(b *testing.B) { var c circuit - ccs, err := frontend.Compile(ecc.{{ .CurveID }}, backend.UNKNOWN, &c, frontend.WithBuilder(r1cs.NewBuilder)) + ccs, err := frontend.Compile(ecc.{{ .CurveID }},r1cs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 14933f3931..3a9719888e 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -4,7 +4,7 @@ import ( {{ template "import_backend_cs" . }} {{ template "import_fft" . }} "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/backend/compiled" + "github.com/consensys/gnark/frontend/compiled" "math/big" "math/bits" ) @@ -315,13 +315,13 @@ func setupABC(r1cs *cs.R1CS, domain *fft.Domain, toxicWaste toxicWaste) (A []fr. // A, B or C at the indice of the variable for i, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { accumulate(&A[t.WireID()], t, &L) } - for _, t := range c.R.LinExp { + for _, t := range c.R { accumulate(&B[t.WireID()], t, &L) } - for _, t := range c.O.LinExp { + for _, t := range c.O { accumulate(&C[t.WireID()], t, &L) } @@ -470,10 +470,10 @@ func dummyInfinityCount(r1cs *cs.R1CS) (nbZeroesA, nbZeroesB int) { A := make([]bool, nbWires) B := make([]bool, nbWires) for _, c := range r1cs.Constraints { - for _, t := range c.L.LinExp { + for _, t := range c.L { A[t.WireID()] = true } - for _, t := range c.R.LinExp { + for _, t := range c.R { B[t.WireID()] = true } } diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl index b83cda5236..4b0abd9e7c 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.go.tmpl @@ -38,7 +38,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit) { circuit := refCircuit{ nbConstraints: nbConstraints, } - r1cs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(r1cs.NewBuilder)) + r1cs, err := frontend.Compile(curve.ID,r1cs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl index e378d356f6..d53aff5dd5 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/plonk.go.tmpl @@ -10,10 +10,10 @@ import ( "testing" "reflect" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/plonk" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend/cs/scs" ) @@ -43,7 +43,7 @@ func referenceCircuit() (frontend.CompiledConstraintSystem, frontend.Circuit, *k circuit := refCircuit{ nbConstraints: nbConstraints, } - ccs, err := frontend.Compile(curve.ID, backend.UNKNOWN, &circuit, frontend.WithBuilder(plonk.NewBuilder)) + ccs, err := frontend.Compile(curve.ID, scs.NewCompiler, &circuit) if err != nil { panic(err) } diff --git a/internal/utils/circuit.go b/internal/utils/circuit.go index 2c2c4f8a1d..d4b585bf78 100644 --- a/internal/utils/circuit.go +++ b/internal/utils/circuit.go @@ -1,67 +1 @@ package utils - -import ( - "fmt" - "reflect" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/schema" -) - -// ShallowClone clones given circuit -// this is actually a shallow copy → if the circuits contains maps or slices -// only the reference is copied. -func ShallowClone(circuit frontend.Circuit) frontend.Circuit { - - cValue := reflect.ValueOf(circuit).Elem() - newCircuit := reflect.New(cValue.Type()) - newCircuit.Elem().Set(cValue) - - circuitCopy, ok := newCircuit.Interface().(frontend.Circuit) - if !ok { - panic("couldn't clone the circuit") - } - - if !reflect.DeepEqual(circuitCopy, circuit) { - panic("clone failed") - } - - return circuitCopy -} - -func CopyWitness(to, from frontend.Circuit) { - var wValues []interface{} - - var collectHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { - v := tInput.Interface().(frontend.Variable) - - if visibility == schema.Secret || visibility == schema.Public { - if v == nil { - return fmt.Errorf("when parsing variable %s: missing assignment", name) - } - wValues = append(wValues, v) - } - return nil - } - if _, err := schema.Parse(from, tVariable, collectHandler); err != nil { - panic(err) - } - - i := 0 - var setHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { - if visibility == schema.Secret || visibility == schema.Public { - tInput.Set(reflect.ValueOf((wValues[i]))) - i++ - } - return nil - } - // this can't error. - _, _ = schema.Parse(to, tVariable, setHandler) - -} - -var tVariable reflect.Type - -func init() { - tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type() -} diff --git a/std/algebra/fields_bls24315/e24_test.go b/std/algebra/fields_bls24315/e24_test.go index 526ea0828e..38ce758adb 100644 --- a/std/algebra/fields_bls24315/e24_test.go +++ b/std/algebra/fields_bls24315/e24_test.go @@ -21,8 +21,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/test" ) @@ -414,7 +414,7 @@ func BenchmarkMulE24(b *testing.B) { var c fp24Mul b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -425,7 +425,7 @@ func BenchmarkSquareE24(b *testing.B) { var c fp24Square b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -436,7 +436,7 @@ func BenchmarkInverseE24(b *testing.B) { var c fp24Inverse b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -447,7 +447,7 @@ func BenchmarkConjugateE24(b *testing.B) { var c fp24Conjugate b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -458,7 +458,7 @@ func BenchmarkMulBy034E24(b *testing.B) { var c fp24MulBy034 b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index 7494b8d594..150722b912 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -199,8 +199,8 @@ func (p *G1Affine) Double(api frontend.API, p1 G1Affine) *G1Affine { // then the compiled circuit depends on s. If it is variable type, then // the circuit is independent of the inputs. func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Affine { - if api.IsConstant(s) { - return P.constScalarMul(api, Q, api.ConstantValue(s)) + if n, ok := api.Compiler().ConstantValue(s); ok { + return P.constScalarMul(api, Q, n) } else { return P.varScalarMul(api, Q, s) } @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS12377 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1, 3) +}) func init() { hint.Register(scalarDecompositionHintBLS12377) @@ -249,12 +249,12 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // points and the operations on the points are performed on the `inner` // curve of the outer curve. We require some parameters from the inner // curve. - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS12377, s) + sd, err := api.Compiler().NewHint(scalarDecompositionHintBLS12377, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -343,7 +343,7 @@ func (P *G1Affine) constScalarMul(api frontend.API, Q G1Affine, s *big.Int) *G1A // bits are constant and here it makes sense to use the table in the main // loop. var Acc, negQ, negPhiQ, phiQ G1Affine - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) s.Mod(s, cc.fr) cc.phi(api, &phiQ, &Q) diff --git a/std/algebra/sw_bls12377/g1_test.go b/std/algebra/sw_bls12377/g1_test.go index e1f696ead9..878aa9f3bb 100644 --- a/std/algebra/sw_bls12377/g1_test.go +++ b/std/algebra/sw_bls12377/g1_test.go @@ -22,8 +22,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -391,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -399,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } @@ -420,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -428,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls12377/g2_test.go b/std/algebra/sw_bls12377/g2_test.go index a32e0e145b..46cb6ed6d9 100644 --- a/std/algebra/sw_bls12377/g2_test.go +++ b/std/algebra/sw_bls12377/g2_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls12377" "github.com/consensys/gnark/test" @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls12377/pairing_test.go b/std/algebra/sw_bls12377/pairing_test.go index 81185f5494..9b88ff0bd8 100644 --- a/std/algebra/sw_bls12377/pairing_test.go +++ b/std/algebra/sw_bls12377/pairing_test.go @@ -25,6 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra/fields_bls12377" "github.com/consensys/gnark/test" ) @@ -199,7 +200,7 @@ func BenchmarkPairing(b *testing.B) { var c pairingBLS377 b.ResetTimer() for i := 0; i < b.N; i++ { - frontend.Compile(ecc.BW6_761, backend.PLONK, &c) + frontend.Compile(ecc.BW6_761, scs.NewCompiler, &c) } // ccsBench, _ = compiler.Compile(ecc.BW6_761, backend.GROTH16, &c) // b.Log("groth16", ccsBench.GetNbConstraints()) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index 4afde9a1cd..7bcd1d9c24 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -199,8 +199,8 @@ func (p *G1Affine) Double(api frontend.API, p1 G1Affine) *G1Affine { // then the compiled circuit depends on s. If it is variable type, then // the circuit is independent of the inputs. func (P *G1Affine) ScalarMul(api frontend.API, Q G1Affine, s interface{}) *G1Affine { - if api.IsConstant(s) { - return P.constScalarMul(api, Q, api.ConstantValue(s)) + if n, ok := api.Compiler().ConstantValue(s); ok { + return P.constScalarMul(api, Q, n) } else { return P.varScalarMul(api, Q, s) } @@ -225,7 +225,7 @@ var scalarDecompositionHintBLS24315 = hint.NewStaticHint(func(curve ecc.ID, inpu res[2].Div(res[2], cc.fr) return nil -}, 1, 3) +}) func init() { hint.Register(scalarDecompositionHintBLS24315) @@ -249,12 +249,12 @@ func (P *G1Affine) varScalarMul(api frontend.API, Q G1Affine, s frontend.Variabl // points and the operations on the points are performed on the `inner` // curve of the outer curve. We require some parameters from the inner // curve. - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod r, // where λ is third root of one in 𝔽_r. - sd, err := api.NewHint(scalarDecompositionHintBLS24315, s) + sd, err := api.Compiler().NewHint(scalarDecompositionHintBLS24315, 3, s) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -343,7 +343,7 @@ func (P *G1Affine) constScalarMul(api frontend.API, Q G1Affine, s *big.Int) *G1A // bits are constant and here it makes sense to use the table in the main // loop. var Acc, negQ, negPhiQ, phiQ G1Affine - cc := innerCurve(api.Curve()) + cc := innerCurve(api.Compiler().Curve()) s.Mod(s, cc.fr) cc.phi(api, &phiQ, &Q) diff --git a/std/algebra/sw_bls24315/g1_test.go b/std/algebra/sw_bls24315/g1_test.go index 663799eaac..577865a75e 100644 --- a/std/algebra/sw_bls24315/g1_test.go +++ b/std/algebra/sw_bls24315/g1_test.go @@ -22,8 +22,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" @@ -391,7 +392,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -399,7 +400,7 @@ func BenchmarkConstScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } @@ -420,7 +421,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { c.R = r b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -428,7 +429,7 @@ func BenchmarkVarScalarMulG1(b *testing.B) { b.Run("plonk", func(b *testing.B) { var err error for i := 0; i < b.N; i++ { - ccsBench, err = frontend.Compile(ecc.BW6_633, backend.PLONK, &c) + ccsBench, err = frontend.Compile(ecc.BW6_633, scs.NewCompiler, &c) if err != nil { b.Fatal(err) } diff --git a/std/algebra/sw_bls24315/g2_test.go b/std/algebra/sw_bls24315/g2_test.go index af80416be3..4fb2c941bb 100644 --- a/std/algebra/sw_bls24315/g2_test.go +++ b/std/algebra/sw_bls24315/g2_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls24315" "github.com/consensys/gnark/test" @@ -272,7 +272,7 @@ func BenchmarkDoubleAffineG2(b *testing.B) { var c g2DoubleAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -283,7 +283,7 @@ func BenchmarkAddAssignAffineG2(b *testing.B) { var c g2AddAssignAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) @@ -294,7 +294,7 @@ func BenchmarkDoubleAndAddAffineG2(b *testing.B) { var c g2DoubleAndAddAffine b.Run("groth16", func(b *testing.B) { for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) } }) diff --git a/std/algebra/sw_bls24315/pairing_test.go b/std/algebra/sw_bls24315/pairing_test.go index a858c7c19b..ab3d16fd97 100644 --- a/std/algebra/sw_bls24315/pairing_test.go +++ b/std/algebra/sw_bls24315/pairing_test.go @@ -25,6 +25,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/algebra/fields_bls24315" "github.com/consensys/gnark/test" ) @@ -213,6 +214,6 @@ func TestTriplePairingBLS24315(t *testing.T) { func BenchmarkPairing(b *testing.B) { var c pairingBLS24315 - ccsBench, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &c) + ccsBench, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index ff879c2fa2..8b78d922ff 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -22,8 +22,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/twistededwards/bandersnatch" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/test" ) @@ -34,7 +34,7 @@ type mustBeOnCurve struct { func (circuit *mustBeOnCurve) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -69,7 +69,7 @@ type add struct { func (circuit *add) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -125,7 +125,7 @@ type addGeneric struct { func (circuit *addGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -180,7 +180,7 @@ type double struct { func (circuit *double) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -228,7 +228,7 @@ type scalarMulFixed struct { func (circuit *scalarMulFixed) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -279,7 +279,7 @@ type scalarMulGeneric struct { func (circuit *scalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -364,36 +364,36 @@ func TestNeg(t *testing.T) { // Bench func BenchmarkDouble(b *testing.B) { var c double - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddGeneric(b *testing.B) { var c addGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkAddFixedPoint(b *testing.B) { var c add - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkMustBeOnCurve(b *testing.B) { var c mustBeOnCurve - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulGeneric(b *testing.B) { var c scalarMulGeneric - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } func BenchmarkScalarMulFixed(b *testing.B) { var c scalarMulFixed - ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + ccsBench, _ := frontend.Compile(ecc.BLS12_381, r1cs.NewCompiler, &c) b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index d9387e94fd..14dcb47b08 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -39,7 +39,7 @@ type mustBeOnCurve struct { func (circuit *mustBeOnCurve) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -74,7 +74,7 @@ type add struct { func (circuit *add) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -130,7 +130,7 @@ type addGeneric struct { func (circuit *addGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -271,7 +271,7 @@ type double struct { func (circuit *double) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -372,7 +372,7 @@ type scalarMulFixed struct { func (circuit *scalarMulFixed) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -473,7 +473,7 @@ type scalarMulGeneric struct { func (circuit *scalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } @@ -608,7 +608,7 @@ type doubleScalarMulGeneric struct { func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error { // get edwards curve params - params, err := NewEdCurve(api.Curve()) + params, err := NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/std/fiat-shamir/transcript_test.go b/std/fiat-shamir/transcript_test.go index ab728cb05c..2911e84874 100644 --- a/std/fiat-shamir/transcript_test.go +++ b/std/fiat-shamir/transcript_test.go @@ -23,8 +23,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/hash" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" ) @@ -43,7 +43,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error { } // get the challenges - alpha, beta, gamma := getChallenges(api.Curve()) + alpha, beta, gamma := getChallenges(api.Compiler().Curve()) // New transcript with 3 challenges to be derived tsSnark := NewTranscript(api, &hSnark, alpha, beta, gamma) @@ -156,7 +156,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BN254, backend.PLONK, &circuit) + ccs, _ = frontend.Compile(ecc.BN254, scs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls12377/verifier_test.go b/std/groth16_bls12377/verifier_test.go index fc571f7115..f90dc0974a 100644 --- a/std/groth16_bls12377/verifier_test.go +++ b/std/groth16_bls12377/verifier_test.go @@ -24,6 +24,7 @@ import ( bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs" groth16_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/groth16" "github.com/consensys/gnark/internal/backend/bls12-377/witness" @@ -61,7 +62,7 @@ func generateBls12377InnerProof(t *testing.T, vk *groth16_bls12377.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS12_377, backend.GROTH16, &circuit) + r1cs, err := frontend.Compile(ecc.BLS12_377, r1cs.NewCompiler, &circuit) if err != nil { t.Fatal(err) } @@ -200,7 +201,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_761, backend.GROTH16, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_761, r1cs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/groth16_bls24315/verifier_test.go b/std/groth16_bls24315/verifier_test.go index 95a71be3ec..dff6d1902b 100644 --- a/std/groth16_bls24315/verifier_test.go +++ b/std/groth16_bls24315/verifier_test.go @@ -24,6 +24,7 @@ import ( bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" backend_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/cs" groth16_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/groth16" "github.com/consensys/gnark/internal/backend/bls24-315/witness" @@ -62,7 +63,7 @@ func generateBls24315InnerProof(t *testing.T, vk *groth16_bls24315.VerifyingKey, // create a mock cs: knowing the preimage of a hash using mimc var circuit, w mimcCircuit - r1cs, err := frontend.Compile(ecc.BLS24_315, backend.GROTH16, &circuit) + r1cs, err := frontend.Compile(ecc.BLS24_315, r1cs.NewCompiler, &circuit) if err != nil { t.Fatal(err) } @@ -201,7 +202,7 @@ func BenchmarkCompile(b *testing.B) { var ccs frontend.CompiledConstraintSystem b.ResetTimer() for i := 0; i < b.N; i++ { - ccs, _ = frontend.Compile(ecc.BW6_633, backend.GROTH16, &circuit) + ccs, _ = frontend.Compile(ecc.BW6_633, r1cs.NewCompiler, &circuit) } b.Log(ccs.GetNbConstraints()) } diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 90debe3d7e..36a9d07567 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -36,7 +36,7 @@ type MiMC struct { // NewMiMC returns a MiMC instance, than can be used in a gnark circuit func NewMiMC(api frontend.API) (MiMC, error) { - if constructor, ok := newMimc[api.Curve()]; ok { + if constructor, ok := newMimc[api.Compiler().Curve()]; ok { return constructor(api), nil } return MiMC{}, errors.New("unknown curve id") diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 84636c0a4b..9bc4265074 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -138,7 +138,7 @@ func parsePoint(id ecc.ID, buf []byte) ([]byte, []byte) { func (circuit *eddsaCircuit) Define(api frontend.API) error { - params, err := twistededwards.NewEdCurve(api.Curve()) + params, err := twistededwards.NewEdCurve(api.Compiler().Curve()) if err != nil { return err } diff --git a/test/assert.go b/test/assert.go index fa807e5000..aa0f5d506e 100644 --- a/test/assert.go +++ b/test/assert.go @@ -29,8 +29,9 @@ import ( "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/backend/witness" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/backend/compiled" - "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/stretchr/testify/require" ) @@ -340,7 +341,7 @@ func (assert *Assert) Fuzz(circuit frontend.Circuit, fuzzCount int, opts ...Test // first we clone the circuit // then we parse the frontend.Variable and set them to a random value or from our interesting pool // (% of allocations to be tuned) - w := utils.ShallowClone(circuit) + w := shallowClone(circuit) fillers := []filler{randomFiller, binaryFiller, seedFiller} @@ -398,14 +399,26 @@ func (assert *Assert) compile(circuit frontend.Circuit, curveID ecc.ID, backendI // TODO we may want to check that it was compiled with the same compile options here return ccs, nil } + + var newBuilder frontend.NewCompiler + + switch backendID { + case backend.GROTH16: + newBuilder = r1cs.NewCompiler + case backend.PLONK: + newBuilder = scs.NewCompiler + default: + panic("not implemented") + } + // else compile it and ensure it is deterministic - ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) + ccs, err := frontend.Compile(curveID, newBuilder, circuit, compileOpts...) // ccs, err := compiler.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, err } - _ccs, err := frontend.Compile(curveID, backendID, circuit, compileOpts...) + _ccs, err := frontend.Compile(curveID, newBuilder, circuit, compileOpts...) // _ccs, err := compiler.Compile(curveID, backendID, circuit, compileOpts...) if err != nil { return nil, fmt.Errorf("%w: %v", ErrCompilationNotDeterministic, err) diff --git a/test/engine.go b/test/engine.go index 9469dedd06..4819834e50 100644 --- a/test/engine.go +++ b/test/engine.go @@ -20,11 +20,13 @@ import ( "fmt" "math/big" "path/filepath" + "reflect" "runtime" "strconv" "strings" "github.com/consensys/gnark/debug" + "github.com/consensys/gnark/frontend/schema" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -71,10 +73,10 @@ func IsSolved(circuit, witness frontend.Circuit, curveID ecc.ID, b backend.ID, o // then, we set all the variables values to the ones from the witness // clone the circuit - c := utils.ShallowClone(circuit) + c := shallowClone(circuit) // set the witness values - utils.CopyWitness(c, witness) + copyWitness(c, witness) defer func() { if r := recover(); r != nil { @@ -325,9 +327,9 @@ func (e *engine) Println(a ...frontend.Variable) { fmt.Println(sbb.String()) } -func (e *engine) NewHint(f hint.Function, inputs ...frontend.Variable) ([]frontend.Variable, error) { +func (e *engine) NewHint(f hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - if f.NbOutputs(e.Curve(), len(inputs)) <= 0 { + if nbOutputs <= 0 { return nil, fmt.Errorf("hint function must return at least one output") } @@ -337,7 +339,7 @@ func (e *engine) NewHint(f hint.Function, inputs ...frontend.Variable) ([]fronte v := e.toBigInt(inputs[i]) in[i] = &v } - res := make([]*big.Int, f.NbOutputs(e.Curve(), len(inputs))) + res := make([]*big.Int, nbOutputs) for i := range res { res[i] = new(big.Int) } @@ -366,9 +368,20 @@ func (e *engine) IsConstant(v frontend.Variable) bool { // ConstantValue returns the big.Int value of v // will panic if v.IsConstant() == false -func (e *engine) ConstantValue(v frontend.Variable) *big.Int { +func (e *engine) ConstantValue(v frontend.Variable) (*big.Int, bool) { r := e.toBigInt(v) - return &r + return &r, true +} + +func (e *engine) IsBoolean(v frontend.Variable) bool { + r := e.toBigInt(v) + return r.IsUint64() && r.Uint64() <= 1 +} + +func (e *engine) MarkBoolean(v frontend.Variable) { + if !e.IsBoolean(v) { + panic("mark boolean a non-boolean value") + } } func (e *engine) Tag(name string) frontend.Tag { @@ -408,3 +421,59 @@ func (e *engine) Curve() ecc.ID { func (e *engine) Backend() backend.ID { return e.backendID } + +// shallowClone clones given circuit +// this is actually a shallow copy → if the circuits contains maps or slices +// only the reference is copied. +func shallowClone(circuit frontend.Circuit) frontend.Circuit { + + cValue := reflect.ValueOf(circuit).Elem() + newCircuit := reflect.New(cValue.Type()) + newCircuit.Elem().Set(cValue) + + circuitCopy, ok := newCircuit.Interface().(frontend.Circuit) + if !ok { + panic("couldn't clone the circuit") + } + + if !reflect.DeepEqual(circuitCopy, circuit) { + panic("clone failed") + } + + return circuitCopy +} + +func copyWitness(to, from frontend.Circuit) { + var wValues []interface{} + + var collectHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { + v := tInput.Interface().(frontend.Variable) + + if visibility == schema.Secret || visibility == schema.Public { + if v == nil { + return fmt.Errorf("when parsing variable %s: missing assignment", name) + } + wValues = append(wValues, v) + } + return nil + } + if _, err := schema.Parse(from, tVariable, collectHandler); err != nil { + panic(err) + } + + i := 0 + var setHandler schema.LeafHandler = func(visibility schema.Visibility, name string, tInput reflect.Value) error { + if visibility == schema.Secret || visibility == schema.Public { + tInput.Set(reflect.ValueOf((wValues[i]))) + i++ + } + return nil + } + // this can't error. + _, _ = schema.Parse(to, tVariable, setHandler) + +} + +func (e *engine) Compiler() frontend.Compiler { + return e +} diff --git a/test/engine_test.go b/test/engine_test.go index a40119daf3..59808beae1 100644 --- a/test/engine_test.go +++ b/test/engine_test.go @@ -15,22 +15,22 @@ type hintCircuit struct { } func (circuit *hintCircuit) Define(api frontend.API) error { - res, err := api.NewHint(hint.IthBit, circuit.A, 3) + res, err := api.Compiler().NewHint(hint.IthBit, 1, circuit.A, 3) if err != nil { return fmt.Errorf("IthBit circuitA 3: %w", err) } a3b := res[0] - res, err = api.NewHint(hint.IthBit, circuit.A, 25) + res, err = api.Compiler().NewHint(hint.IthBit, 1, circuit.A, 25) if err != nil { return fmt.Errorf("IthBit circuitA 25: %w", err) } a25b := res[0] - res, err = api.NewHint(hint.IsZero, circuit.A) + res, err = api.Compiler().NewHint(hint.IsZero, 1, circuit.A) if err != nil { return fmt.Errorf("IsZero CircuitA: %w", err) } aisZero := res[0] - res, err = api.NewHint(hint.IsZero, circuit.B) + res, err = api.Compiler().NewHint(hint.IsZero, 1, circuit.B) if err != nil { return fmt.Errorf("IsZero, CircuitB") }