From 9c6dd9b7245926011fddae0425e6f07b071ca66b Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Thu, 5 Sep 2024 19:37:54 -0400 Subject: [PATCH 1/5] fix: GLV for Bandersnatch --- std/algebra/native/twistededwards/curve.go | 5 +- std/algebra/native/twistededwards/hints.go | 105 ++++++++++++++++++ .../native/twistededwards/scalarmul_glv.go | 82 ++------------ std/math/emulated/emparams/emparams.go | 32 ++++++ 4 files changed, 148 insertions(+), 76 deletions(-) create mode 100644 std/algebra/native/twistededwards/hints.go diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index bcc5f36119..491d4ee3de 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -47,10 +47,7 @@ func (c *curve) AssertIsOnCurve(p1 Point) { func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { var p Point if c.endo != nil { - // TODO restore - // this is disabled until this issue is solved https://github.com/ConsenSys/gnark/issues/268 - // p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo) - p.scalarMul(c.api, &p1, scalar, c.params) + p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo) } else { p.scalarMul(c.api, &p1, scalar, c.params) } diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go new file mode 100644 index 0000000000..d30acfaad9 --- /dev/null +++ b/std/algebra/native/twistededwards/hints.go @@ -0,0 +1,105 @@ +package twistededwards + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" +) + +func GetHints() []solver.Hint { + return []solver.Hint{ + decomposeScalar, + decompose, + } +} + +func init() { + solver.RegisterHint(GetHints()...) +} + +type glvParams struct { + lambda, order big.Int + glvBasis ecc.Lattice +} + +func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { + return emulated.UnwrapHintWithNativeInput(nativeInputs, nativeOutputs, func(nnMod *big.Int, nninputs, nnOutputs []*big.Int) error { + if len(nninputs) != 1 { + return fmt.Errorf("expecting one input") + } + if len(nnOutputs) != 2 { + return fmt.Errorf("expecting two outputs") + } + var glv glvParams + glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) + sp := ecc.SplitScalar(nninputs[0], &glv.glvBasis) + nnOutputs[0].Set(&(sp[0])) + nnOutputs[1].Neg(&(sp[1])) + + return nil + }) +} + +func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2 frontend.Variable) { + var fr emparams.BandersnatchFr + var glv glvParams + glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + + sapi, err := emulated.NewField[emparams.BandersnatchFr](api) + if err != nil { + panic(err) + } + // compute the decomposition using a hint. We have to use the emulated + // version which takes native input and outputs non-native outputs. + // + // the hints allow to decompose the scalar s into s1 and s2 such that + // s1 + λ * s2 == s mod r, + // where λ is third root of one in 𝔽_r. + sd, err := sapi.NewHintWithNativeInput(decomposeScalar, 2, s) + if err != nil { + panic(err) + } + // lambda as nonnative element + lambdaEmu := sapi.NewElement(glv.lambda) + // the scalar as nonnative element. We need to split at 64 bits. + limbs, err := api.NewHint(decompose, int(fr.NbLimbs()), s) + if err != nil { + panic(err) + } + semu := sapi.NewElement(limbs) + // we negated s2 in decomposeScalar so we check instead: + // s + λ * s2 == s1 mod r + rhs := sapi.MulNoReduce(sd[1], lambdaEmu) + rhs = sapi.Add(rhs, semu) + sapi.AssertIsEqual(rhs, sd[0]) + + s1 = 0 + s2 = 0 + b := big.NewInt(1) + for i := range sd[0].Limbs { + s1 = api.Add(s1, api.Mul(sd[0].Limbs[i], b)) + s2 = api.Add(s2, api.Mul(sd[1].Limbs[i], b)) + b.Lsh(b, 64) + } + return s1, s2 +} + +func decompose(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 1 && len(outputs) != 4 { + return fmt.Errorf("input/output length mismatch") + } + tmp := new(big.Int).Set(inputs[0]) + mask := new(big.Int).SetUint64(^uint64(0)) + for i := 0; i < 4; i++ { + outputs[i].And(tmp, mask) + tmp.Rsh(tmp, 64) + } + return nil +} diff --git a/std/algebra/native/twistededwards/scalarmul_glv.go b/std/algebra/native/twistededwards/scalarmul_glv.go index 7b959a2db4..19b0fb14b9 100644 --- a/std/algebra/native/twistededwards/scalarmul_glv.go +++ b/std/algebra/native/twistededwards/scalarmul_glv.go @@ -16,15 +16,7 @@ limitations under the License. package twistededwards -import ( - "errors" - "math/big" - "sync" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/frontend" -) +import "github.com/consensys/gnark/frontend" // phi endomorphism √-2 ∈ 𝒪₋₈ // (x,y) → λ × (x,y) s.t. λ² = -2 mod Order @@ -44,47 +36,6 @@ func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoP return p } -type glvParams struct { - lambda, order big.Int - glvBasis ecc.Lattice -} - -var DecomposeScalar = func(scalarField *big.Int, inputs []*big.Int, res []*big.Int) error { - // the efficient endomorphism exists on Bandersnatch only - if scalarField.Cmp(ecc.BLS12_381.ScalarField()) != 0 { - return errors.New("no efficient endomorphism is available on this curve") - } - var glv glvParams - var init sync.Once - init.Do(func() { - glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) - glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) - ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) - }) - - // sp[0] is always negative because, in SplitScalar(), we always round above - // the determinant/2 computed in PrecomputeLattice() which is negative for Bandersnatch. - // Thus taking -sp[0] here and negating the point in ScalarMul(). - // If we keep -sp[0] it will be reduced mod r (the BLS12-381 prime order) - // and not the Bandersnatch prime order (Order) and the result will be incorrect. - // Also, if we reduce it mod Order here, we can't use api.ToBinary(sp[0], 129) - // and hence we can't reduce optimally the number of constraints. - sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) - res[0].Neg(&(sp[0])) - res[1].Set(&(sp[1])) - - // figure out how many times we have overflowed - res[2].Mul(res[1], &glv.lambda).Sub(res[2], res[0]) - res[2].Sub(res[2], inputs[0]) - res[2].Div(res[2], &glv.order) - - return nil -} - -func init() { - solver.RegisterHint(DecomposeScalar) -} - // ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve @@ -94,37 +45,24 @@ func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variab // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod Order, // with λ s.t. λ² = -2 mod Order. - sd, err := api.NewHint(DecomposeScalar, 3, scalar) - if err != nil { - // err is non-nil only for invalid number of inputs - panic(err) - } - - s1, s2 := sd[0], sd[1] - - // -s1 + λ * s2 == s + k*Order - api.AssertIsEqual(api.Sub(api.Mul(s2, endo.Lambda), s1), api.Add(scalar, api.Mul(curve.Order, sd[2]))) + s1, s2 := callDecomposeScalar(api, scalar) - // Normally s1 and s2 are of the max size sqrt(Order) = 128 - // But in a circuit, we force s1 to be negative by rounding always above. - // This changes the size bounds to 2*sqrt(Order) = 129. - n := 129 + n := 127 b1 := api.ToBinary(s1, n) b2 := api.ToBinary(s2, n) - var res, _p1, p2, p3, tmp Point - _p1.neg(api, p1) - p2.phi(api, p1, curve, endo) - p3.add(api, &_p1, &p2, curve) + var res, p2, p3, tmp Point + p2.phi(api, p1, curve, endo).neg(api, &p2) + p3.add(api, p1, &p2, curve) - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, p3.Y) for i := n - 2; i >= 0; i-- { res.double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, p3.Y) res.add(api, &res, &tmp, curve) } diff --git a/std/math/emulated/emparams/emparams.go b/std/math/emulated/emparams/emparams.go index 55d1008cf1..b61f6930e0 100644 --- a/std/math/emulated/emparams/emparams.go +++ b/std/math/emulated/emparams/emparams.go @@ -282,6 +282,38 @@ type BLS24315Fr struct{ fourLimbPrimeField } func (fr BLS24315Fr) Modulus() *big.Int { return ecc.BLS24_315.ScalarField() } +// BandersnatchFp provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 (base 16) +// 52435875175126190479447740508185965837690552500527637822603658699938581184513 (base 10) +// +// This is the base field of the Bandersnatch curve. +type BandersnatchFp struct{ fourLimbPrimeField } + +func (fp BandersnatchFp) Modulus() *big.Int { return ecc.BLS12_381.ScalarField() } + +// BandersnatchFr provides type parametrization for field emulation: +// - limbs: 4 +// - limb width: 64 bits +// +// The prime modulus for type parametrisation is: +// +// 0x1cfb69d4ca675f520cce760202687600ff8f87007419047174fd06b52876e7e1 (base 16) +// 13108968793781547619861935127046491459309155893440570251786403306729687672801 (base 10) +// +// This is the scalar field of the Bandersnatch curve. +type BandersnatchFr struct{ fourLimbPrimeField } + +func (fp BandersnatchFr) Modulus() *big.Int { + var scalarField big.Int + scalarField.SetString("1cfb69d4ca675f520cce760202687600ff8f87007419047174fd06b52876e7e1", 16) + return &scalarField +} + // Mod1e4096 provides type parametrization for emulated arithmetic: // - limbs: 64 // - limb width: 64 bits From d74fced2e0277aced6d005e2c685ce8368b44f0f Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Thu, 5 Sep 2024 20:05:42 -0400 Subject: [PATCH 2/5] fix: s2 can be negative sometimes --- std/algebra/native/twistededwards/hints.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index d30acfaad9..3a982acc0b 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -43,6 +43,11 @@ func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) nnOutputs[0].Set(&(sp[0])) nnOutputs[1].Neg(&(sp[1])) + // TODO: @yelhousni handle negative s2 + if nnOutputs[1].Sign() == -1 { + panic(fmt.Errorf("negative s2 not handled yet")) + } + return nil }) } From 31738b63e5dffb4e78599428c37760126e402ae1 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 6 Sep 2024 10:42:51 -0400 Subject: [PATCH 3/5] fix: s2 negative case --- std/algebra/native/twistededwards/hints.go | 44 ++++++++--- std/algebra/native/twistededwards/point.go | 57 +++++++++++++++ .../native/twistededwards/scalarmul_glv.go | 73 ------------------- 3 files changed, 92 insertions(+), 82 deletions(-) delete mode 100644 std/algebra/native/twistededwards/scalarmul_glv.go diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index 3a982acc0b..51ab40bb08 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -14,6 +14,7 @@ import ( func GetHints() []solver.Hint { return []solver.Hint{ decomposeScalar, + decomposeScalarSigns, decompose, } } @@ -41,18 +42,37 @@ func decomposeScalar(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) sp := ecc.SplitScalar(nninputs[0], &glv.glvBasis) nnOutputs[0].Set(&(sp[0])) - nnOutputs[1].Neg(&(sp[1])) + nnOutputs[1].Set(&(sp[1])) - // TODO: @yelhousni handle negative s2 if nnOutputs[1].Sign() == -1 { - panic(fmt.Errorf("negative s2 not handled yet")) + nnOutputs[1].Neg(nnOutputs[1]) } return nil }) } -func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2 frontend.Variable) { +func decomposeScalarSigns(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("expecting one input") + } + if len(outputs) != 1 { + return fmt.Errorf("expecting one output") + } + var glv glvParams + glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + glv.order.SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + ecc.PrecomputeLattice(&glv.order, &glv.lambda, &glv.glvBasis) + sp := ecc.SplitScalar(inputs[0], &glv.glvBasis) + outputs[0].SetUint64(0) + if sp[1].Sign() == -1 { + outputs[0].SetUint64(1) + } + + return nil +} + +func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2, s3 frontend.Variable) { var fr emparams.BandersnatchFr var glv glvParams glv.lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) @@ -71,6 +91,10 @@ func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2 frontend if err != nil { panic(err) } + bit, err := api.NewHint(decomposeScalarSigns, 1, s) + if err != nil { + panic(err) + } // lambda as nonnative element lambdaEmu := sapi.NewElement(glv.lambda) // the scalar as nonnative element. We need to split at 64 bits. @@ -80,20 +104,22 @@ func callDecomposeScalar(api frontend.API, s frontend.Variable) (s1, s2 frontend } semu := sapi.NewElement(limbs) // we negated s2 in decomposeScalar so we check instead: - // s + λ * s2 == s1 mod r - rhs := sapi.MulNoReduce(sd[1], lambdaEmu) - rhs = sapi.Add(rhs, semu) - sapi.AssertIsEqual(rhs, sd[0]) + // s1 + λ * s2 == s mod r + _s1 := sapi.Select(bit[0], sapi.Neg(sd[1]), sd[1]) + rhs := sapi.MulNoReduce(_s1, lambdaEmu) + rhs = sapi.Add(rhs, sd[0]) + sapi.AssertIsEqual(rhs, semu) s1 = 0 s2 = 0 + s3 = bit[0] b := big.NewInt(1) for i := range sd[0].Limbs { s1 = api.Add(s1, api.Mul(sd[0].Limbs[i], b)) s2 = api.Add(s2, api.Mul(sd[1].Limbs[i], b)) b.Lsh(b, 64) } - return s1, s2 + return s1, s2, s3 } func decompose(mod *big.Int, inputs, outputs []*big.Int) error { diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index dbacdb30d5..c18dd8701a 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -172,3 +172,60 @@ func (p *Point) doubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 fron return p } + +// GLV scalar multiplication + +// phi endomorphism √-2 ∈ 𝒪₋₈ +// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order +func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { + + xy := api.Mul(p1.X, p1.Y) + yy := api.Mul(p1.Y, p1.Y) + f := api.Sub(1, yy) + f = api.Mul(f, endo.Endo[1]) + g := api.Add(yy, endo.Endo[0]) + g = api.Mul(g, endo.Endo[0]) + h := api.Sub(yy, endo.Endo[0]) + + p.X = api.DivUnchecked(f, xy) + p.Y = api.DivUnchecked(g, h) + + return p +} + +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // the hints allow to decompose the scalar s into s1 and s2 such that + // s1 + λ * s2 == s mod Order, + // with λ s.t. λ² = -2 mod Order. + s1, s2, s3 := callDecomposeScalar(api, scalar) + + n := 127 + + b1 := api.ToBinary(s1, n) + b2 := api.ToBinary(s2, n) + + var res, p2, p3, tmp Point + p2.phi(api, p1, curve, endo) + p2.X = api.Select(s3, api.Neg(p2.X), p2.X) + p3.add(api, p1, &p2, curve) + + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, p3.Y) + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, p3.Y) + res.add(api, &res, &tmp, curve) + } + + p.X = res.X + p.Y = res.Y + + return p +} diff --git a/std/algebra/native/twistededwards/scalarmul_glv.go b/std/algebra/native/twistededwards/scalarmul_glv.go deleted file mode 100644 index 19b0fb14b9..0000000000 --- a/std/algebra/native/twistededwards/scalarmul_glv.go +++ /dev/null @@ -1,73 +0,0 @@ -/* -Copyright © 2022 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package twistededwards - -import "github.com/consensys/gnark/frontend" - -// phi endomorphism √-2 ∈ 𝒪₋₈ -// (x,y) → λ × (x,y) s.t. λ² = -2 mod Order -func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { - - xy := api.Mul(p1.X, p1.Y) - yy := api.Mul(p1.Y, p1.Y) - f := api.Sub(1, yy) - f = api.Mul(f, endo.Endo[1]) - g := api.Add(yy, endo.Endo[0]) - g = api.Mul(g, endo.Endo[0]) - h := api.Sub(yy, endo.Endo[0]) - - p.X = api.DivUnchecked(f, xy) - p.Y = api.DivUnchecked(g, h) - - return p -} - -// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve -// p1: base point (as snark point) -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { - // the hints allow to decompose the scalar s into s1 and s2 such that - // s1 + λ * s2 == s mod Order, - // with λ s.t. λ² = -2 mod Order. - s1, s2 := callDecomposeScalar(api, scalar) - - n := 127 - - b1 := api.ToBinary(s1, n) - b2 := api.ToBinary(s2, n) - - var res, p2, p3, tmp Point - p2.phi(api, p1, curve, endo).neg(api, &p2) - p3.add(api, p1, &p2, curve) - - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, p3.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, p3.Y) - - for i := n - 2; i >= 0; i-- { - res.double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, p3.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, p3.Y) - res.add(api, &res, &tmp, curve) - } - - p.X = res.X - p.Y = res.Y - - return p -} From 5adde720845b875c81af97d22ed97052b788b095 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 6 Sep 2024 17:17:36 -0400 Subject: [PATCH 4/5] refactor: glv in tEdwards only in r1cs --- std/algebra/native/twistededwards/curve.go | 6 +--- std/algebra/native/twistededwards/point.go | 34 +++++++++++++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index 491d4ee3de..9349e276e5 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -46,11 +46,7 @@ func (c *curve) AssertIsOnCurve(p1 Point) { } func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { var p Point - if c.endo != nil { - p.scalarMulGLV(c.api, &p1, scalar, c.params, c.endo) - } else { - p.scalarMul(c.api, &p1, scalar, c.params) - } + p.scalarMul(c.api, &p1, scalar, c.params, c.endo) return p } func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index c18dd8701a..424c48e75c 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -18,6 +18,7 @@ package twistededwards import ( "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/frontendtype" ) // neg computes the negative of a point in SNARK coordinates @@ -95,17 +96,12 @@ func (p *Point) double(api frontend.API, p1 *Point, curve *CurveParams) *Point { return p } -// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// scalarMulGeneric computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { - if len(endo) == 1 && endo[0] != nil { - // use glv - return p.scalarMulGLV(api, p1, scalar, curve, endo[0]) - } - +func (p *Point) scalarMulGeneric(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { // first unpack the scalar b := api.ToBinary(scalar) @@ -142,6 +138,28 @@ func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, return p } +// scalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// p1: base point (as snark point) +// curve: parameters of the Edwards curve +// scal: scalar as a SNARK constraint +// Standard left to right double and add +func (p *Point) scalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { + if ft, ok := api.(frontendtype.FrontendTyper); ok { + switch ft.FrontendType() { + case frontendtype.R1CS: + if len(endo) == 1 && endo[0] != nil { + // use glv + return p.scalarMulGLV(api, p1, scalar, curve, endo[0]) + } else { + return p.scalarMulGeneric(api, p1, scalar, curve) + } + case frontendtype.SCS: + return p.scalarMulGeneric(api, p1, scalar, curve) + } + } + return p.scalarMulGeneric(api, p1, scalar, curve) +} + // doubleBaseScalarMul computes s1*P1+s2*P2 // where P1 and P2 are points on a twisted Edwards curve // and s1, s2 scalars. @@ -193,7 +211,7 @@ func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoP return p } -// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve +// scalarMulGLV computes the scalar multiplication of a point on a twisted Edwards curve à la GLV // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint From b2cb485edb07d864fe32196120b9e450bf28725d Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Tue, 15 Oct 2024 14:16:23 -0400 Subject: [PATCH 5/5] fix: bandersnatch GLV edge case with 0-point --- .../native/twistededwards/curve_test.go | 28 +++++++++++++++++ std/algebra/native/twistededwards/point.go | 30 ++++++++++++------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/std/algebra/native/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go index b12a027c01..a2e7a36ad4 100644 --- a/std/algebra/native/twistededwards/curve_test.go +++ b/std/algebra/native/twistededwards/curve_test.go @@ -152,6 +152,34 @@ func (circuit *addCircuit) Define(api frontend.API) error { api.AssertIsEqual(res.Y, circuit.ScalarMulResult.Y) } + { + // scalar mul zero-scalar edge-case + res := curve.ScalarMul(circuit.P2, 0) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + } + + { + // scalar mul zero-point edge-case + res := curve.ScalarMul(Point{0, 1}, circuit.S2) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + } + + { + // scalar mul zero-scalar and zero-point edge-case + res := curve.ScalarMul(Point{0, 1}, 0) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + } + + { + // scalar mul one-scalar edge-case + res := curve.ScalarMul(circuit.P2, 1) + api.AssertIsEqual(res.X, circuit.P2.X) + api.AssertIsEqual(res.Y, circuit.P2.Y) + } + { // double scalar mul res := curve.DoubleBaseScalarMul(circuit.P1, circuit.P2, circuit.S1, circuit.S2) diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index 424c48e75c..cc9865e77e 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -101,7 +101,7 @@ func (p *Point) double(api frontend.API, p1 *Point, curve *CurveParams) *Point { // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) scalarMulGeneric(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo ...*EndoParams) *Point { +func (p *Point) scalarMulGeneric(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams) *Point { // first unpack the scalar b := api.ToBinary(scalar) @@ -195,7 +195,7 @@ func (p *Point) doubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 fron // phi endomorphism √-2 ∈ 𝒪₋₈ // (x,y) → λ × (x,y) s.t. λ² = -2 mod Order -func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { +func (p *Point) phi(api frontend.API, p1 *Point, endo *EndoParams) *Point { xy := api.Mul(p1.X, p1.Y) yy := api.Mul(p1.Y, p1.Y) @@ -217,6 +217,7 @@ func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoP // scal: scalar as a SNARK constraint // Standard left to right double and add func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // the hints allow to decompose the scalar s into s1 and s2 such that // s1 + λ * s2 == s mod Order, // with λ s.t. λ² = -2 mod Order. @@ -227,23 +228,30 @@ func (p *Point) scalarMulGLV(api frontend.API, p1 *Point, scalar frontend.Variab b1 := api.ToBinary(s1, n) b2 := api.ToBinary(s2, n) - var res, p2, p3, tmp Point - p2.phi(api, p1, curve, endo) + var _p1, res, p2, p3, tmp Point + // the endomorphism is not defined for point with X=0 or Y=0 Y=0 points are + // not on the prime subgroup and X=0 point is the zero-point (0,1). + // So we replace p1=(0,1) with a dummy point (3,1) and continue at the end + // we return (0,1). + selector := api.IsZero(p1.X) + _p1.X = api.Select(selector, 3, p1.X) + _p1.Y = p1.Y + p2.phi(api, &_p1, endo) p2.X = api.Select(s3, api.Neg(p2.X), p2.X) - p3.add(api, p1, &p2, curve) + p3.add(api, &_p1, &p2, curve) - res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, p3.X) - res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, p3.Y) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, _p1.X, p2.X, p3.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, _p1.Y, p2.Y, p3.Y) for i := n - 2; i >= 0; i-- { res.double(api, &res, curve) - tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, p3.X) - tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, p3.Y) + tmp.X = api.Lookup2(b1[i], b2[i], 0, _p1.X, p2.X, p3.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, _p1.Y, p2.Y, p3.Y) res.add(api, &res, &tmp, curve) } - p.X = res.X - p.Y = res.Y + p.X = api.Select(selector, 0, res.X) + p.Y = api.Select(selector, 1, res.Y) return p }