Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: frontend.Variable is now an interface, simplifies witness assignment and constant usage #180

Merged
merged 13 commits into from
Nov 17, 2021
Merged
55 changes: 16 additions & 39 deletions backend/witness/witness.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,6 @@ import (
"reflect"

"github.com/consensys/gnark-crypto/ecc"
fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr"
fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
"github.com/consensys/gnark/frontend"
witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness"
witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness"
Expand Down Expand Up @@ -151,7 +146,7 @@ func WriteSequence(w io.Writer, circuit frontend.Circuit) error {
}
return nil
}
if err := parser.Visit(circuit, "", compiled.Unset, collectHandler, reflect.TypeOf(frontend.Variable{})); err != nil {
if err := parser.Visit(circuit, "", compiled.Unset, collectHandler, tVariable); err != nil {
return err
}

Expand Down Expand Up @@ -196,7 +191,7 @@ func ReadPublicFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int6
}
return nil
}
_ = parser.Visit(witness, "", compiled.Unset, collectHandler, reflect.TypeOf(frontend.Variable{}))
_ = parser.Visit(witness, "", compiled.Unset, collectHandler, tVariable)

if nbPublic == 0 {
return 0, nil
Expand All @@ -212,7 +207,7 @@ func ReadPublicFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int6
return 4, errors.New("invalid witness size")
}

elementSize := getElementSize(curveID)
elementSize := curveID.Info().Fr.Bytes

expectedSize := elementSize * nbPublic

Expand All @@ -227,14 +222,12 @@ func ReadPublicFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int6
if err != nil {
return err
}
v := tInput.Interface().(frontend.Variable)
v.Assign(new(big.Int).SetBytes(bufElement))
tInput.Set(reflect.ValueOf(v))
tInput.Set(reflect.ValueOf(new(big.Int).SetBytes(bufElement)))
}
return nil
}

if err := parser.Visit(witness, "", compiled.Unset, reader, reflect.TypeOf(frontend.Variable{})); err != nil {
if err := parser.Visit(witness, "", compiled.Unset, reader, tVariable); err != nil {
return int64(read), err
}

Expand All @@ -258,7 +251,7 @@ func ReadFullFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int64,
}
return nil
}
_ = parser.Visit(witness, "", compiled.Unset, collectHandler, reflect.TypeOf(frontend.Variable{}))
_ = parser.Visit(witness, "", compiled.Unset, collectHandler, tVariable)

if nbPublic == 0 && nbSecrets == 0 {
return 0, nil
Expand All @@ -274,7 +267,7 @@ func ReadFullFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int64,
return 4, errors.New("invalid witness size")
}

elementSize := getElementSize(curveID)
elementSize := curveID.Info().Fr.Bytes
expectedSize := elementSize * (nbPublic + nbSecrets)

lr := io.LimitReader(r, int64(expectedSize*elementSize))
Expand All @@ -289,9 +282,7 @@ func ReadFullFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int64,
if err != nil {
return err
}
v := tInput.Interface().(frontend.Variable)
v.Assign(new(big.Int).SetBytes(bufElement))
tInput.Set(reflect.ValueOf(v))
tInput.Set(reflect.ValueOf(new(big.Int).SetBytes(bufElement)))
}
return nil
}
Expand All @@ -305,38 +296,18 @@ func ReadFullFrom(r io.Reader, curveID ecc.ID, witness frontend.Circuit) (int64,
}

// public
if err := parser.Visit(witness, "", compiled.Unset, publicReader, reflect.TypeOf(frontend.Variable{})); err != nil {
if err := parser.Visit(witness, "", compiled.Unset, publicReader, tVariable); err != nil {
return int64(read), err
}

// secret
if err := parser.Visit(witness, "", compiled.Unset, secretReader, reflect.TypeOf(frontend.Variable{})); err != nil {
if err := parser.Visit(witness, "", compiled.Unset, secretReader, tVariable); err != nil {
return int64(read), err
}

return int64(read), nil
}

func getElementSize(curve ecc.ID) int {
// now compute expected size from field element size.
var elementSize int
switch curve {
case ecc.BLS12_377:
elementSize = fr_bls12377.Bytes
case ecc.BLS12_381:
elementSize = fr_bls12381.Bytes
case ecc.BLS24_315:
elementSize = fr_bls24315.Bytes
case ecc.BN254:
elementSize = fr_bn254.Bytes
case ecc.BW6_761:
elementSize = fr_bw6761.Bytes
default:
panic("not implemented")
}
return elementSize
}

// ToJSON outputs a JSON string with variableName: value
// values are first converted to field element (mod base curve modulus)
func ToJSON(witness frontend.Circuit, curveID ecc.ID) (string, error) {
Expand All @@ -355,3 +326,9 @@ func ToJSON(witness frontend.Circuit, curveID ecc.ID) (string, error) {
panic("not implemented")
}
}

var tVariable reflect.Type

func init() {
tVariable = reflect.ValueOf(struct{ A frontend.Variable }{}).FieldByName("A").Type()
}
12 changes: 6 additions & 6 deletions backend/witness/witness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ type circuit struct {
E frontend.Variable
}

func (circuit *circuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *circuit) Define(api frontend.API) error {
return nil
}

func TestReconstructionPublic(t *testing.T) {
assert := require.New(t)

var wPublic, wPublicReconstructed circuit
wPublic.X.Assign(new(big.Int).SetInt64(42))
wPublic.Y.Assign(new(big.Int).SetInt64(8000))
wPublic.X = new(big.Int).SetInt64(42)
wPublic.Y = new(big.Int).SetInt64(8000)

var buf bytes.Buffer
written, err := WritePublicTo(&buf, ecc.BN254, &wPublic)
Expand All @@ -48,9 +48,9 @@ func TestReconstructionFull(t *testing.T) {
assert := require.New(t)

var wFull, wFullReconstructed circuit
wFull.X.Assign(new(big.Int).SetInt64(42))
wFull.Y.Assign(new(big.Int).SetInt64(8000))
wFull.E.Assign(new(big.Int).SetInt64(1))
wFull.X = new(big.Int).SetInt64(42)
wFull.Y = new(big.Int).SetInt64(8000)
wFull.E = new(big.Int).SetInt64(1)

var buf bytes.Buffer
written, err := WriteFullTo(&buf, ecc.BN254, &wFull)
Expand Down
30 changes: 15 additions & 15 deletions debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type printlnCircuit struct {
A, B frontend.Variable
}

func (circuit *printlnCircuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *printlnCircuit) Define(api frontend.API) error {
c := api.Add(circuit.A, circuit.B)
api.Println(c, "is the addition")
d := api.Mul(circuit.A, c)
Expand All @@ -39,8 +39,8 @@ func TestPrintln(t *testing.T) {
assert := require.New(t)

var circuit, witness printlnCircuit
witness.A.Assign(2)
witness.B.Assign(11)
witness.A = 2
witness.B = 11

var expected bytes.Buffer
expected.WriteString("debug_test.go:25 13 is the addition\n")
Expand All @@ -66,7 +66,7 @@ type divBy0Trace struct {
A, B, C frontend.Variable
}

func (circuit *divBy0Trace) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *divBy0Trace) Define(api frontend.API) error {
d := api.Add(circuit.B, circuit.C)
api.Div(circuit.A, d)
return nil
Expand All @@ -76,9 +76,9 @@ func TestTraceDivBy0(t *testing.T) {
assert := require.New(t)

var circuit, witness divBy0Trace
witness.A.Assign(2)
witness.B.Assign(-2)
witness.C.Assign(2)
witness.A = 2
witness.B = -2
witness.C = 2

{
_, err := getGroth16Trace(&circuit, &witness)
Expand All @@ -103,7 +103,7 @@ type notEqualTrace struct {
A, B, C frontend.Variable
}

func (circuit *notEqualTrace) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *notEqualTrace) Define(api frontend.API) error {
d := api.Add(circuit.B, circuit.C)
api.AssertIsEqual(circuit.A, d)
return nil
Expand All @@ -113,9 +113,9 @@ func TestTraceNotEqual(t *testing.T) {
assert := require.New(t)

var circuit, witness notEqualTrace
witness.A.Assign(1)
witness.B.Assign(24)
witness.C.Assign(42)
witness.A = 1
witness.B = 24
witness.C = 42

{
_, err := getGroth16Trace(&circuit, &witness)
Expand All @@ -140,7 +140,7 @@ type notBooleanTrace struct {
B, C frontend.Variable
}

func (circuit *notBooleanTrace) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *notBooleanTrace) Define(api frontend.API) error {
d := api.Add(circuit.B, circuit.C)
api.AssertIsBoolean(d)
return nil
Expand All @@ -150,9 +150,9 @@ func TestTraceNotBoolean(t *testing.T) {
assert := require.New(t)

var circuit, witness notBooleanTrace
// witness.A.Assign(1)
witness.B.Assign(24)
witness.C.Assign(42)
// witness.A = 1
witness.B = 24
witness.C = 42

{
_, err := getGroth16Trace(&circuit, &witness)
Expand Down
3 changes: 1 addition & 2 deletions examples/cubic/cubic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package cubic

import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
)

Expand All @@ -30,7 +29,7 @@ type Circuit struct {

// Define declares the circuit constraints
// x**3 + x + 5 == y
func (circuit *Circuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *Circuit) Define(api frontend.API) error {
x3 := api.Mul(circuit.X, circuit.X, circuit.X)
api.AssertIsEqual(circuit.Y, api.Add(x3, circuit.X, 5))
return nil
Expand Down
9 changes: 4 additions & 5 deletions examples/cubic/cubic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package cubic
import (
"testing"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)

Expand All @@ -27,13 +26,13 @@ func TestCubicEquation(t *testing.T) {
var cubicCircuit Circuit

assert.ProverFailed(&cubicCircuit, &Circuit{
X: frontend.Value(42),
Y: frontend.Value(42),
X: 42,
Y: 42,
})

assert.ProverSucceeded(&cubicCircuit, &Circuit{
X: frontend.Value(3),
Y: frontend.Value(35),
X: 3,
Y: 35,
})

}
5 changes: 2 additions & 3 deletions examples/exponentiate/exponentiate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package exponentiate

import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
)

Expand All @@ -32,13 +31,13 @@ type Circuit struct {

// Define declares the circuit's constraints
// y == x**e
func (circuit *Circuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *Circuit) Define(api frontend.API) error {

// number of bits of exponent
const bitSize = 8

// specify constraints
output := api.Constant(1)
output := frontend.Variable(1)
bits := api.ToBinary(circuit.E, bitSize)
api.ToBinary(circuit.E, bitSize)

Expand Down
13 changes: 6 additions & 7 deletions examples/exponentiate/exponentiate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package exponentiate
import (
"testing"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)

Expand All @@ -28,15 +27,15 @@ func TestExponentiateGroth16(t *testing.T) {
var expCircuit Circuit

assert.ProverFailed(&expCircuit, &Circuit{
X: frontend.Value(2),
E: frontend.Value(12),
Y: frontend.Value(4095),
X: 2,
E: 12,
Y: 4095,
})

assert.ProverSucceeded(&expCircuit, &Circuit{
X: frontend.Value(2),
E: frontend.Value(12),
Y: frontend.Value(4096),
X: 2,
E: 12,
Y: 4096,
})

}
5 changes: 2 additions & 3 deletions examples/mimc/mimc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package mimc

import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash/mimc"
)
Expand All @@ -31,9 +30,9 @@ type Circuit struct {

// Define declares the circuit's constraints
// Hash = mimc(PreImage)
func (circuit *Circuit) Define(curveID ecc.ID, api frontend.API) error {
func (circuit *Circuit) Define(api frontend.API) error {
// hash function
mimc, _ := mimc.NewMiMC("seed", curveID, api)
mimc, _ := mimc.NewMiMC("seed", api)

// specify constraints
// mimc(preImage) == hash
Expand Down
9 changes: 4 additions & 5 deletions examples/mimc/mimc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)

Expand All @@ -28,13 +27,13 @@ func TestPreimage(t *testing.T) {
var mimcCircuit Circuit

assert.ProverFailed(&mimcCircuit, &Circuit{
Hash: frontend.Value(42),
PreImage: frontend.Value(42),
Hash: 42,
PreImage: 42,
})

assert.ProverSucceeded(&mimcCircuit, &Circuit{
PreImage: frontend.Value(35),
Hash: frontend.Value("16130099170765464552823636852555369511329944820189892919423002775646948828469"),
PreImage: 35,
Hash: "16130099170765464552823636852555369511329944820189892919423002775646948828469",
}, test.WithCurves(ecc.BN254))

}
Loading