Skip to content

Commit

Permalink
Merge pull request #271 from ConsenSys/refactor-compiled
Browse files Browse the repository at this point in the history
`frontend/` refactor: separate responsabilities to `Builder`, `Compiler` and `API` interfaces
  • Loading branch information
gbotrel authored Mar 1, 2022
2 parents 7149365 + 88838e8 commit 534a171
Show file tree
Hide file tree
Showing 109 changed files with 2,484 additions and 2,784 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (circuit *CubicCircuit) Define(api frontend.API) error {

// compiles our circuit into a R1CS
var circuit CubicCircuit
ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit)
ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit)

// groth16 zkSNARK: Setup
pk, vk, err := groth16.Setup(ccs)
Expand Down
5 changes: 0 additions & 5 deletions backend/groth16/groth16.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
backend_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs"
backend_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/cs"
backend_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/cs"
Expand All @@ -52,10 +51,6 @@ import (
groth16_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/groth16"
)

func init() {
frontend.RegisterDefaultBuilder(backend.GROTH16, r1cs.NewBuilder)
}

// TODO @gbotrel document hint functions here and in assert

type groth16Object interface {
Expand Down
19 changes: 17 additions & 2 deletions backend/hint/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ var initBuiltinOnce sync.Once

func init() {
initBuiltinOnce.Do(func() {
IsZero = NewStaticHint(builtinIsZero, 1, 1)
IsZero = NewStaticHint(builtinIsZero)
Register(IsZero)
IthBit = NewStaticHint(builtinIthBit, 2, 1)
IthBit = NewStaticHint(builtinIthBit)
Register(IthBit)
NBits = NewStaticHint(builtinNBits)
Register(NBits)
})
}

// TODO FIXME these may be redefined easily by an external package

// The package provides the following built-in hint functions. All built-in hint
// functions are registered in the registry.
var (
Expand All @@ -30,6 +34,9 @@ var (
// integer inputs i and n, takes the little-endian bit representation of n and
// returns its i-th bit.
IthBit Function

// NBits returns the n first bits of the input. Expects one argument: n.
NBits Function
)

func builtinIsZero(curveID ecc.ID, inputs []*big.Int, results []*big.Int) error {
Expand Down Expand Up @@ -63,3 +70,11 @@ func builtinIthBit(_ ecc.ID, inputs []*big.Int, results []*big.Int) error {
result.SetUint64(uint64(inputs[0].Bit(int(inputs[1].Uint64()))))
return nil
}

func builtinNBits(_ ecc.ID, inputs []*big.Int, results []*big.Int) error {
n := inputs[0]
for i := 0; i < len(results); i++ {
results[i].SetUint64(uint64(n.Bit(i)))
}
return nil
}
61 changes: 14 additions & 47 deletions backend/hint/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type ID uint32
// StaticFunction is a function which takes a constant number of inputs and
// returns a constant number of outputs. Use NewStaticHint() to construct an
// instance compatible with Function interface.
type StaticFunction func(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error
type StaticFunction func(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error

// Function defines an annotated hint function. To initialize a hint function
// with static number of inputs and outputs, use NewStaticHint().
Expand All @@ -91,21 +91,18 @@ type Function interface {
UUID() ID

// Call is invoked by the framework to obtain the result from inputs.
// The length of res is NbOutputs() and every element is
// already initialized (but not necessarily to zero as the elements may be
// obtained from cache). A returned non-nil error will be propagated.
Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error

// NbOutputs returns the total number of outputs by the function when
// invoked on the curveID with nInputs number of inputs. The number of
// outputs must be at least one and the framework errors otherwise.
NbOutputs(curveID ecc.ID, nInputs int) (nOutputs int)
// Elements in outputs are not guaranteed to be initialized to 0
Call(curveID ecc.ID, inputs []*big.Int, outputs []*big.Int) error

// String returns a human-readable description of the function used in logs
// and debug messages.
String() string
}

func NewStaticHint(fn StaticFunction) Function {
return fn
}

// UUID is a reference function for computing the hint ID based on a function
// and additional context values ctx. A change in any of the inputs modifies the
// returned value and thus this function can be used to compute the hint ID for
Expand All @@ -127,46 +124,16 @@ func UUID(fn StaticFunction, ctx ...uint64) ID {
return ID(hf.Sum32())
}

// staticArgumentsFunction defines a function where the number of inputs and
// outputs is constant.
type staticArgumentsFunction struct {
fn StaticFunction
nIn int
nOut int
}

// NewStaticHint returns an Function where the number of inputs and outputs is
// constant. UUID is computed by combining fn, nIn and nOut and thus it is legal
// to defined multiple AnnotatedFunctions on the same fn with different nIn and
// nOut.
func NewStaticHint(fn StaticFunction, nIn, nOut int) Function {
return &staticArgumentsFunction{
fn: fn,
nIn: nIn,
nOut: nOut,
}
}

func (h *staticArgumentsFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error {
if len(inputs) != h.nIn {
return fmt.Errorf("input has %d elements, expected %d", len(inputs), h.nIn)
}
if len(res) != h.nOut {
return fmt.Errorf("result has %d elements, expected %d", len(res), h.nOut)
}
return h.fn(curveID, inputs, res)
}

func (h *staticArgumentsFunction) NbOutputs(_ ecc.ID, _ int) int {
return h.nOut
func (h StaticFunction) Call(curveID ecc.ID, inputs []*big.Int, res []*big.Int) error {
return h(curveID, inputs, res)
}

func (h *staticArgumentsFunction) UUID() ID {
return UUID(h.fn, uint64(h.nIn), uint64(h.nOut))
func (h StaticFunction) UUID() ID {
return UUID(h)
}

func (h *staticArgumentsFunction) String() string {
fnptr := reflect.ValueOf(h.fn).Pointer()
func (h StaticFunction) String() string {
fnptr := reflect.ValueOf(h).Pointer()
name := runtime.FuncForPC(fnptr).Name()
return fmt.Sprintf("%s([%d]*big.Int, [%d]*big.Int) at (%x)", name, h.nIn, h.nOut, fnptr)
return fmt.Sprintf("%s([?]*big.Int, [?]*big.Int) at (%x)", name, fnptr)
}
5 changes: 0 additions & 5 deletions backend/plonk/plonk.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/consensys/gnark-crypto/kzg"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/plonk"

"github.com/consensys/gnark/backend/witness"
cs_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/cs"
Expand Down Expand Up @@ -58,10 +57,6 @@ import (
kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg"
)

func init() {
frontend.RegisterDefaultBuilder(backend.PLONK, plonk.NewBuilder)
}

// Proof represents a Plonk proof generated by plonk.Prove
//
// it's underlying implementation is curve specific (see gnark/internal/backend)
Expand Down
21 changes: 17 additions & 4 deletions circuitstats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/internal/backend/circuits"
"github.com/consensys/gnark/test"
)
Expand All @@ -26,19 +28,30 @@ func TestCircuitStatistics(t *testing.T) {
for _, curve := range ecc.Implemented() {
for _, b := range backend.Implemented() {
curve := curve
b := b
backendID := b
name := k
// copy the circuit now in case assert calls t.Parallel()
tData := circuits.Circuits[k]
assert.Run(func(assert *test.Assert) {
ccs, err := frontend.Compile(curve, b, tData.Circuit)
var newCompiler frontend.NewCompiler

switch backendID {
case backend.GROTH16:
newCompiler = r1cs.NewCompiler
case backend.PLONK:
newCompiler = scs.NewCompiler
default:
panic("not implemented")
}

ccs, err := frontend.Compile(curve, newCompiler, tData.Circuit)
assert.NoError(err)

// ensure we didn't introduce regressions that make circuits less efficient
nbConstraints := ccs.GetNbConstraints()
internal, secret, public := ccs.GetNbVariables()
checkStats(assert, name, nbConstraints, internal, secret, public, curve, b)
}, name, curve.String(), b.String())
checkStats(assert, name, nbConstraints, internal, secret, public, curve, backendID)
}, name, curve.String(), backendID.String())
}
}

Expand Down
16 changes: 9 additions & 7 deletions debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/backend/plonk"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/test"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -43,11 +45,11 @@ func TestPrintln(t *testing.T) {
witness.B = 11

var expected bytes.Buffer
expected.WriteString("debug_test.go:25 13 is the addition\n")
expected.WriteString("debug_test.go:27 26 42\n")
expected.WriteString("debug_test.go:29 bits 1\n")
expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n")
expected.WriteString("debug_test.go:34 m .*\n")
expected.WriteString("debug_test.go:27 13 is the addition\n")
expected.WriteString("debug_test.go:29 26 42\n")
expected.WriteString("debug_test.go:31 bits 1\n")
expected.WriteString("debug_test.go:32 circuit {A: 2, B: 11}\n")
expected.WriteString("debug_test.go:36 m .*\n")

{
trace, _ := getGroth16Trace(&circuit, &witness)
Expand Down Expand Up @@ -172,7 +174,7 @@ func TestTraceNotBoolean(t *testing.T) {
}

func getPlonkTrace(circuit, w frontend.Circuit) (string, error) {
ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, circuit)
ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, circuit)
if err != nil {
return "", err
}
Expand All @@ -196,7 +198,7 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) {
}

func getGroth16Trace(circuit, w frontend.Circuit) (string, error) {
ccs, err := frontend.Compile(ecc.BN254, backend.GROTH16, circuit)
ccs, err := frontend.Compile(ecc.BN254, r1cs.NewCompiler, circuit)
if err != nil {
return "", err
}
Expand Down
4 changes: 2 additions & 2 deletions examples/plonk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"log"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/plonk"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/internal/backend/bn254/cs"
"github.com/consensys/gnark/test"

Expand Down Expand Up @@ -73,7 +73,7 @@ func main() {
var circuit Circuit

// // building the circuit...
ccs, err := frontend.Compile(ecc.BN254, backend.PLONK, &circuit)
ccs, err := frontend.Compile(ecc.BN254, scs.NewCompiler, &circuit)
if err != nil {
fmt.Println("circuit compilation error")
}
Expand Down
2 changes: 1 addition & 1 deletion examples/rollup/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ type TransferConstraints struct {

func (circuit *Circuit) postInit(api frontend.API) error {
// edward curve params
params, err := twistededwards.NewEdCurve(api.Curve())
params, err := twistededwards.NewEdCurve(api.Compiler().Curve())
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions examples/serialization/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"github.com/fxamacker/cbor/v2"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"

"github.com/consensys/gnark/examples/cubic"
)
Expand All @@ -17,7 +17,7 @@ func main() {
var circuit cubic.Circuit

// compile a circuit
_r1cs, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &circuit)
_r1cs, _ := frontend.Compile(ecc.BN254, r1cs.NewCompiler, &circuit)

// R1CS implements io.WriterTo and io.ReaderFrom
var buf bytes.Buffer
Expand Down
46 changes: 20 additions & 26 deletions frontend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type API interface {

// ---------------------------------------------------------------------------------------------
// Bit operations
// TODO @gbotrel move bit operations in std/math/bits

// ToBinary unpacks a Variable in binary,
// n is the number of bits to select (starting from lsb)
Expand Down Expand Up @@ -111,39 +112,32 @@ type API interface {
// whose value will be resolved at runtime when computed by the solver
Println(a ...Variable)

// NewHint initializes internal variables whose value will be evaluated
// using the provided hint function at run time from the inputs. Inputs must
// be either variables or convertible to *big.Int. The function returns an
// error if the number of inputs is not compatible with f.
//
// The hint function is provided at the proof creation time and is not
// embedded into the circuit. From the backend point of view, the variable
// returned by the hint function is equivalent to the user-supplied witness,
// but its actual value is assigned by the solver, not the caller.
//
// No new constraints are added to the newly created wire and must be added
// manually in the circuit. Failing to do so leads to solver failure.
NewHint(f hint.Function, inputs ...Variable) ([]Variable, error)
// Compiler returns the compiler object for advanced circuit development
Compiler() Compiler

// Deprecated APIs

// Tag creates a tag at a given place in a circuit. The state of the tag may contain informations needed to
// measure constraints, variables and coefficients creations through AddCounter
// NewHint is a shorcut to api.Compiler().NewHint()
// Deprecated: use api.Compiler().NewHint() instead
NewHint(f hint.Function, nbOutputs int, inputs ...Variable) ([]Variable, error)

// Tag is a shorcut to api.Compiler().Tag()
// Deprecated: use api.Compiler().Tag() instead
Tag(name string) Tag

// AddCounter measures the number of constraints, variables and coefficients created between two tags
// note that the PlonK statistics are contextual since there is a post-compile phase where linear expressions
// are factorized. That is, measuring 2 times the "repeating" piece of circuit may give less constraints the second time
// AddCounter is a shorcut to api.Compiler().AddCounter()
// Deprecated: use api.Compiler().AddCounter() instead
AddCounter(from, to Tag)

// IsConstant returns true if v is a constant known at compile time
IsConstant(v Variable) bool

// ConstantValue returns the big.Int value of v. It
// panics if v.IsConstant() == false
ConstantValue(v Variable) *big.Int
// ConstantValue is a shorcut to api.Compiler().ConstantValue()
// Deprecated: use api.Compiler().ConstantValue() instead
ConstantValue(v Variable) (*big.Int, bool)

// CurveID returns the ecc.ID injected by the compiler
// Curve is a shorcut to api.Compiler().Curve()
// Deprecated: use api.Compiler().Curve() instead
Curve() ecc.ID

// Backend returns the backend.ID injected by the compiler
// Backend is a shorcut to api.Compiler().Backend()
// Deprecated: use api.Compiler().Backend() instead
Backend() backend.ID
}
2 changes: 1 addition & 1 deletion frontend/ccs.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
"github.com/consensys/gnark/frontend/compiled"
"github.com/consensys/gnark/frontend/schema"
"github.com/consensys/gnark/internal/backend/compiled"
)

// CompiledConstraintSystem interface that a compiled (=typed, and correctly routed)
Expand Down
Loading

0 comments on commit 534a171

Please sign in to comment.