Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Bandersnatch GLV scalar multiplication #1271

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions std/algebra/native/twistededwards/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@ func (c *curve) AssertIsOnCurve(p1 Point) {
}
func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point {
var p Point
if c.endo != nil {
// TODO restore
// this is disabled until this issue is solved https://github.com/ConsenSys/gnark/issues/268
// p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo)
p.scalarMul(c.api, &p1, scalar, c.params)
} else {
p.scalarMul(c.api, &p1, scalar, c.params)
}
p.scalarMul(c.api, &p1, scalar, c.params, c.endo)
return p
}
func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point {
Expand Down
28 changes: 28 additions & 0 deletions std/algebra/native/twistededwards/curve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,34 @@ func (circuit *addCircuit) Define(api frontend.API) error {
api.AssertIsEqual(res.Y, circuit.ScalarMulResult.Y)
}

{
// scalar mul zero-scalar edge-case
res := curve.ScalarMul(circuit.P2, 0)
api.AssertIsEqual(res.X, 0)
api.AssertIsEqual(res.Y, 1)
}

{
// scalar mul zero-point edge-case
res := curve.ScalarMul(Point{0, 1}, circuit.S2)
api.AssertIsEqual(res.X, 0)
api.AssertIsEqual(res.Y, 1)
}

{
// scalar mul zero-scalar and zero-point edge-case
res := curve.ScalarMul(Point{0, 1}, 0)
api.AssertIsEqual(res.X, 0)
api.AssertIsEqual(res.Y, 1)
}

{
// scalar mul one-scalar edge-case
res := curve.ScalarMul(circuit.P2, 1)
api.AssertIsEqual(res.X, circuit.P2.X)
api.AssertIsEqual(res.Y, circuit.P2.Y)
}

{
// double scalar mul
res := curve.DoubleBaseScalarMul(circuit.P1, circuit.P2, circuit.S1, circuit.S2)
Expand Down
136 changes: 136 additions & 0 deletions std/algebra/native/twistededwards/hints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package twistededwards

import (
"fmt"
"math/big"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
)

func GetHints() []solver.Hint {
return []solver.Hint{
decomposeScalar,
decomposeScalarSigns,
decompose,
}
}

func init() {
solver.RegisterHint(GetHints()...)
}

type glvParams struct {
lambda, order big.Int
glvBasis ecc.Lattice
}

func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nnMod *big.Int, nninputs, nnOutputs []*big.Int) error {
if len(nninputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(nnOutputs) != 2 {
return fmt.Errorf("expecting two outputs")
}
var glv glvParams
glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10)
glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10)
ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis)
sp := ecc.SplitScalar(nninputs[0], &glv.glvBasis)
nnOutputs[0].Set(&(sp[0]))
nnOutputs[1].Set(&(sp[1]))

if nnOutputs[1].Sign() == -1 {
nnOutputs[1].Neg(nnOutputs[1])
}

return nil
})
}

func decomposeScalarSigns(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("expecting one input")
}
if len(outputs) != 1 {
return fmt.Errorf("expecting one output")
}
var glv glvParams
glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10)
glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10)
ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis)
sp := ecc.SplitScalar(inputs[0], &glv.glvBasis)
outputs[0].SetUint64(0)
if sp[1].Sign() == -1 {
outputs[0].SetUint64(1)
}

return nil
}

func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2, s3 frontend.Variable) {
var fr emparams.BandersnatchFr
var glv glvParams
glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10)

sapi, err := emulated.NewField[emparams.BandersnatchFr](api)
if err != nil {
panic(err)
}
// compute the decomposition using a hint. We have to use the emulated
// version which takes native input and outputs non-native outputs.
//
// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod r,
// where λ is third root of one in 𝔽_r.
sd, err := sapi.NewHintWithNativeInput(decomposeScalar, 2, s)
if err != nil {
panic(err)
}
bit, err := api.NewHint(decomposeScalarSigns, 1, s)
if err != nil {
panic(err)
}
// lambda as nonnative element
lambdaEmu := sapi.NewElement(glv.lambda)
// the scalar as nonnative element. We need to split at 64 bits.
limbs, err := api.NewHint(decompose, int(fr.NbLimbs()), s)
if err != nil {
panic(err)
}
semu := sapi.NewElement(limbs)
// we negated s2 in decomposeScalar so we check instead:
// s1 + λ * s2 == s mod r
_s1 := sapi.Select(bit[0], sapi.Neg(sd[1]), sd[1])
rhs := sapi.MulNoReduce(_s1, lambdaEmu)
rhs = sapi.Add(rhs, sd[0])
sapi.AssertIsEqual(rhs, semu)

s1 = 0
s2 = 0
s3 = bit[0]
b := big.NewInt(1)
for i := range sd[0].Limbs {
s1 = api.Add(s1, api.Mul(sd[0].Limbs[i], b))
s2 = api.Add(s2, api.Mul(sd[1].Limbs[i], b))
b.Lsh(b, 64)
}
return s1, s2, s3
}

func decompose(mod *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 1 && len(outputs) != 4 {
return fmt.Errorf("input/output length mismatch")
}
tmp := new(big.Int).Set(inputs[0])
mask := new(big.Int).SetUint64(^uint64(0))
for i := 0; i < 4; i++ {
outputs[i].And(tmp, mask)
tmp.Rsh(tmp, 64)
}
return nil
}
97 changes: 90 additions & 7 deletions std/algebra/native/twistededwards/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package twistededwards

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/frontendtype"
)

// neg computes the negative of a point in SNARK coordinates
Expand Down Expand Up @@ -95,17 +96,12 @@ func (p *Point) double(api frontend.API, p1 *Point, curve *CurveParams) *Point {
return p
}

// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve
// scalarMulGeneric computes the scalar multiplication of a point on a twisted Edwards curve
// p1: base point (as snark point)
// curve: parameters of the Edwards curve
// scal: scalar as a SNARK constraint
// Standard left to right double and add
func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point {
if len(endo) == 1 && endo[0] != nil {
// use glv
return p.scalarMulGLV(api, p1, scalar, curve, endo[0])
}

func (p *Point) scalarMulGeneric(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams) *Point {
// first unpack the scalar
b := api.ToBinary(scalar)

Expand Down Expand Up @@ -142,6 +138,28 @@ func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable,
return p
}

// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve
// p1: base point (as snark point)
// curve: parameters of the Edwards curve
// scal: scalar as a SNARK constraint
// Standard left to right double and add
func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point {
if ft, ok := api.(frontendtype.FrontendTyper); ok {
switch ft.FrontendType() {
case frontendtype.R1CS:
if len(endo) == 1 && endo[0] != nil {
// use glv
return p.scalarMulGLV(api, p1, scalar, curve, endo[0])
} else {
return p.scalarMulGeneric(api, p1, scalar, curve)
}
case frontendtype.SCS:
return p.scalarMulGeneric(api, p1, scalar, curve)
}
}
return p.scalarMulGeneric(api, p1, scalar, curve)
}

// doubleBaseScalarMul computes s1*P1+s2*P2
// where P1 and P2 are points on a twisted Edwards curve
// and s1, s2 scalars.
Expand Down Expand Up @@ -172,3 +190,68 @@ func (p *Point) doubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 fron

return p
}

// GLV scalar multiplication

// phi endomorphism √-2 ∈ 𝒪₋₈
// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order
func (p *Point) phi(api frontend.API, p1 *Point, endo *EndoParams) *Point {

xy := api.Mul(p1.X, p1.Y)
yy := api.Mul(p1.Y, p1.Y)
f := api.Sub(1, yy)
f = api.Mul(f, endo.Endo[1])
g := api.Add(yy, endo.Endo[0])
g = api.Mul(g, endo.Endo[0])
h := api.Sub(yy, endo.Endo[0])

p.X = api.DivUnchecked(f, xy)
p.Y = api.DivUnchecked(g, h)

return p
}

// scalarMulGLV computes the scalar multiplication of a point on a twisted Edwards curve à la GLV
// p1: base point (as snark point)
// curve: parameters of the Edwards curve
// scal: scalar as a SNARK constraint
// Standard left to right double and add
func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point {

// the hints allow to decompose the scalar s into s1 and s2 such that
// s1 + λ * s2 == s mod Order,
// with λ s.t. λ² = -2 mod Order.
s1, s2, s3 := callDecomposeScalar(api, scalar)

n := 127

b1 := api.ToBinary(s1, n)
b2 := api.ToBinary(s2, n)

var _p1, res, p2, p3, tmp Point
// the endomorphism is not defined for point with X=0 or Y=0 Y=0 points are
// not on the prime subgroup and X=0 point is the zero-point (0,1).
// So we replace p1=(0,1) with a dummy point (3,1) and continue at the end
// we return (0,1).
selector := api.IsZero(p1.X)
_p1.X = api.Select(selector, 3, p1.X)
_p1.Y = p1.Y
p2.phi(api, &_p1, endo)
p2.X = api.Select(s3, api.Neg(p2.X), p2.X)
p3.add(api, &_p1, &p2, curve)

res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X)
res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y)

for i := n - 2; i >= 0; i-- {
res.double(api, &res, curve)
tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X)
tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y)
res.add(api, &res, &tmp, curve)
}

p.X = api.Select(selector, 0, res.X)
p.Y = api.Select(selector, 1, res.Y)

return p
}
Loading